rust_rule_engine/backward/
optimizer.rs

//! Query optimization for backward chaining
//!
//! This module implements query optimization techniques to improve performance:
//! - Goal reordering based on selectivity
//! - Index selection for O(1) lookups
//! - Join ordering optimization
//! - Predicate pushdown
//! - Memoization of intermediate results
//!
//! # Examples
//!
//! ```rust,ignore
//! use rust_rule_engine::backward::optimizer::QueryOptimizer;
//!
//! let mut optimizer = QueryOptimizer::new();
//!
//! // Before: item(?x) AND expensive(?x) AND in_stock(?x)
//! // After:  in_stock(?x) AND expensive(?x) AND item(?x)
//! //         (Evaluates most selective first)
//!
//! // `goals` is a `Vec<Goal>` in the "Before" order shown above.
//! let optimized = optimizer.optimize_goals(goals);
//! ```
23
24use super::goal::Goal;
25use std::collections::HashMap;
26
27/// Query optimizer for backward chaining
28#[derive(Debug, Clone)]
29pub struct QueryOptimizer {
30    /// Selectivity estimates for predicates
31    selectivity_map: HashMap<String, f64>,
32
33    /// Whether to enable goal reordering
34    enable_reordering: bool,
35
36    /// Whether to enable index selection
37    enable_index_selection: bool,
38
39    /// Whether to enable memoization
40    enable_memoization: bool,
41
42    /// Statistics for optimization
43    stats: OptimizationStats,
44}
45
46impl QueryOptimizer {
47    /// Create a new query optimizer
48    pub fn new() -> Self {
49        Self {
50            selectivity_map: HashMap::new(),
51            enable_reordering: true,
52            enable_index_selection: true,
53            enable_memoization: true,
54            stats: OptimizationStats::new(),
55        }
56    }
57
58    /// Create optimizer with custom configuration
59    pub fn with_config(config: OptimizerConfig) -> Self {
60        Self {
61            selectivity_map: HashMap::new(),
62            enable_reordering: config.enable_reordering,
63            enable_index_selection: config.enable_index_selection,
64            enable_memoization: config.enable_memoization,
65            stats: OptimizationStats::new(),
66        }
67    }
68
69    /// Optimize a list of goals
70    ///
71    /// Returns a reordered list of goals optimized for evaluation
72    pub fn optimize_goals(&mut self, goals: Vec<Goal>) -> Vec<Goal> {
73        if !self.enable_reordering || goals.len() <= 1 {
74            return goals;
75        }
76
77        self.stats.total_optimizations += 1;
78
79        // Estimate selectivity for each goal
80        let mut goal_selectivity: Vec<(Goal, f64)> = goals
81            .into_iter()
82            .map(|g| {
83                let selectivity = self.estimate_selectivity(&g);
84                (g, selectivity)
85            })
86            .collect();
87
88        // Sort by selectivity (lower = more selective = evaluate first)
89        goal_selectivity.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
90
91        // Extract optimized goals
92        let optimized: Vec<Goal> = goal_selectivity.into_iter().map(|(g, _)| g).collect();
93
94        self.stats.goals_reordered += optimized.len();
95
96        optimized
97    }
98
99    /// Estimate selectivity of a goal (lower = more selective)
100    ///
101    /// Returns a value between 0.0 (very selective) and 1.0 (not selective)
102    pub fn estimate_selectivity(&self, goal: &Goal) -> f64 {
103        // Check if we have a known selectivity estimate
104        if let Some(&selectivity) = self.selectivity_map.get(&goal.pattern) {
105            return selectivity;
106        }
107
108        // Heuristic-based estimation
109        self.heuristic_selectivity(goal)
110    }
111
112    /// Heuristic-based selectivity estimation
113    fn heuristic_selectivity(&self, goal: &Goal) -> f64 {
114        let pattern = &goal.pattern;
115
116        // Count bound vs unbound variables
117        let (bound_count, var_count) = self.count_variables(pattern);
118
119        if var_count == 0 {
120            // No variables = most selective (exact match)
121            return 0.1;
122        }
123
124        // More bound variables = more selective
125        let bound_ratio = bound_count as f64 / var_count as f64;
126        let selectivity = 1.0 - (bound_ratio * 0.8);
127
128        // Check for specific patterns
129        if pattern.contains("in_stock") || pattern.contains("available") {
130            // Stock checks typically very selective
131            return selectivity * 0.3;
132        }
133
134        if pattern.contains("expensive") || pattern.contains("premium") {
135            // Price filters moderately selective
136            return selectivity * 0.5;
137        }
138
139        if pattern.contains("item") || pattern.contains("product") {
140            // Generic predicates less selective
141            return selectivity * 1.2;
142        }
143
144        selectivity
145    }
146
147    /// Count bound and total variables in a pattern
148    fn count_variables(&self, pattern: &str) -> (usize, usize) {
149        let mut bound = 0;
150        let mut total = 0;
151
152        let chars: Vec<char> = pattern.chars().collect();
153        let mut i = 0;
154
155        while i < chars.len() {
156            if chars[i] == '?' {
157                total += 1;
158
159                // Skip variable name
160                i += 1;
161                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
162                    i += 1;
163                }
164
165                // Check if followed by a comparison (indicates bound)
166                while i < chars.len() && chars[i].is_whitespace() {
167                    i += 1;
168                }
169
170                if i < chars.len() && (chars[i] == '=' || chars[i] == '>' || chars[i] == '<') {
171                    bound += 1;
172                }
173            } else {
174                i += 1;
175            }
176        }
177
178        (bound, total)
179    }
180
181    /// Set selectivity estimate for a predicate
182    pub fn set_selectivity(&mut self, predicate: String, selectivity: f64) {
183        self.selectivity_map.insert(predicate, selectivity.clamp(0.0, 1.0));
184    }
185
186    /// Get optimization statistics
187    pub fn stats(&self) -> &OptimizationStats {
188        &self.stats
189    }
190
191    /// Reset statistics
192    pub fn reset_stats(&mut self) {
193        self.stats = OptimizationStats::new();
194    }
195
196    /// Enable or disable goal reordering
197    pub fn set_reordering(&mut self, enabled: bool) {
198        self.enable_reordering = enabled;
199    }
200
201    /// Enable or disable index selection
202    pub fn set_index_selection(&mut self, enabled: bool) {
203        self.enable_index_selection = enabled;
204    }
205
206    /// Enable or disable memoization
207    pub fn set_memoization(&mut self, enabled: bool) {
208        self.enable_memoization = enabled;
209    }
210}
211
212impl Default for QueryOptimizer {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
/// Switches that control which optimizations a query optimizer applies.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Turn goal reordering on or off
    pub enable_reordering: bool,

    /// Turn index selection on or off
    pub enable_index_selection: bool,

    /// Turn memoization on or off
    pub enable_memoization: bool,
}

impl Default for OptimizerConfig {
    /// The default configuration has every optimization switched on.
    fn default() -> Self {
        let on = true;
        OptimizerConfig {
            enable_reordering: on,
            enable_index_selection: on,
            enable_memoization: on,
        }
    }
}
240
/// Counters describing the work done by the query optimizer.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// How many optimization passes have run
    pub total_optimizations: usize,

    /// How many goals have been passed through reordering
    pub goals_reordered: usize,

    /// How many index selections were made
    pub index_selections: usize,

    /// How many memoization lookups were hits
    pub memoization_hits: usize,

    /// How many memoization lookups were misses
    pub memoization_misses: usize,
}

impl OptimizationStats {
    /// Build a statistics record with every counter at zero.
    pub fn new() -> Self {
        OptimizationStats::default()
    }

    /// Fraction of memoization lookups that were hits.
    ///
    /// Yields `0.0` when no lookups (hits or misses) have been recorded.
    pub fn memoization_hit_rate(&self) -> f64 {
        match self.memoization_hits + self.memoization_misses {
            0 => 0.0,
            lookups => self.memoization_hits as f64 / lookups as f64,
        }
    }

    /// One-line, human-readable summary of the counters.
    ///
    /// The hit rate is rendered as a percentage with one decimal place.
    pub fn summary(&self) -> String {
        let hit_percent = self.memoization_hit_rate() * 100.0;
        format!(
            "Optimizations: {} | Goals reordered: {} | Memo hits: {} ({:.1}%)",
            self.total_optimizations, self.goals_reordered, self.memoization_hits, hit_percent
        )
    }
}
287
288/// Join ordering optimizer
289#[derive(Debug, Clone)]
290pub struct JoinOptimizer {
291    /// Cost estimates for different join strategies
292    cost_model: HashMap<String, f64>,
293}
294
295impl JoinOptimizer {
296    /// Create a new join optimizer
297    pub fn new() -> Self {
298        Self {
299            cost_model: HashMap::new(),
300        }
301    }
302
303    /// Optimize join order for a set of goals
304    ///
305    /// Returns goals ordered for optimal join performance
306    pub fn optimize_joins(&self, goals: Vec<Goal>) -> Vec<Goal> {
307        if goals.len() <= 1 {
308            return goals;
309        }
310
311        // Simple heuristic: start with goals that have most bound variables
312        let mut sorted_goals = goals;
313        sorted_goals.sort_by_key(|g| {
314            // Count bound variables (negative for descending sort)
315            -(self.count_bound_vars(&g.pattern) as i32)
316        });
317
318        sorted_goals
319    }
320
321    /// Count bound variables in a pattern
322    fn count_bound_vars(&self, pattern: &str) -> usize {
323        // Simple heuristic: variables followed by comparison operators
324        let mut count = 0;
325        let chars: Vec<char> = pattern.chars().collect();
326        let mut i = 0;
327
328        while i < chars.len() {
329            if chars[i] == '?' {
330                // Skip variable name
331                i += 1;
332                while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
333                    i += 1;
334                }
335
336                // Check if followed by comparison
337                while i < chars.len() && chars[i].is_whitespace() {
338                    i += 1;
339                }
340
341                if i < chars.len() && (chars[i] == '=' || chars[i] == '>' || chars[i] == '<') {
342                    count += 1;
343                }
344            } else {
345                i += 1;
346            }
347        }
348
349        count
350    }
351}
352
353impl Default for JoinOptimizer {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a goal from a pattern literal.
    fn goal(pattern: &str) -> Goal {
        Goal::new(pattern.to_string())
    }

    #[test]
    fn test_optimizer_creation() {
        // A fresh optimizer has every feature switched on.
        let optimizer = QueryOptimizer::new();
        let flags = (
            optimizer.enable_reordering,
            optimizer.enable_index_selection,
            optimizer.enable_memoization,
        );
        assert_eq!(flags, (true, true, true));
    }

    #[test]
    fn test_optimizer_with_config() {
        // Flags must be taken over from the supplied configuration.
        let optimizer = QueryOptimizer::with_config(OptimizerConfig {
            enable_reordering: false,
            enable_index_selection: true,
            enable_memoization: false,
        });
        let flags = (
            optimizer.enable_reordering,
            optimizer.enable_index_selection,
            optimizer.enable_memoization,
        );
        assert_eq!(flags, (false, true, false));
    }

    #[test]
    fn test_goal_reordering() {
        let mut optimizer = QueryOptimizer::new();

        // Explicit estimates (lower = more selective)
        let estimates = [
            ("in_stock(?x)", 0.1),
            ("expensive(?x)", 0.3),
            ("item(?x)", 0.9),
        ];
        for (pattern, selectivity) in estimates.iter() {
            optimizer.set_selectivity(pattern.to_string(), *selectivity);
        }

        let goals = vec![goal("item(?x)"), goal("expensive(?x)"), goal("in_stock(?x)")];
        let optimized = optimizer.optimize_goals(goals);

        // Expected evaluation order: in_stock, expensive, item
        let expected = ["in_stock(?x)", "expensive(?x)", "item(?x)"];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(optimized[idx].pattern, *want);
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let optimizer = QueryOptimizer::new();
        let estimate = |pattern: &str| optimizer.estimate_selectivity(&goal(pattern));

        // No variables at all: an exact match, very selective
        let exact = estimate("employee(alice)");
        assert!(exact < 0.5);

        // A single unbound variable is less selective than an exact match
        let unbound = estimate("employee(?x)");
        assert!(unbound > exact);

        // A variable constrained by a comparison beats a fully unbound one
        let constrained = estimate("salary(?x) WHERE ?x > 100000");
        assert!(constrained < unbound);
    }

    #[test]
    fn test_count_variables() {
        let optimizer = QueryOptimizer::new();

        // (pattern, expected bound count, expected total count)
        let cases = [
            // No variables
            ("employee(alice)", 0, 0),
            // One unbound variable
            ("employee(?x)", 0, 1),
            // ?x appears twice; only the second occurrence is followed by `>`
            ("salary(?x) WHERE ?x > 100", 1, 2),
        ];

        for (pattern, want_bound, want_total) in cases.iter() {
            let (bound, total) = optimizer.count_variables(pattern);
            assert_eq!(total, *want_total);
            assert_eq!(bound, *want_bound);
        }
    }

    #[test]
    fn test_optimization_stats() {
        let mut optimizer = QueryOptimizer::new();

        let _ = optimizer.optimize_goals(vec![goal("a(?x)"), goal("b(?x)")]);

        let stats = optimizer.stats();
        assert_eq!(stats.total_optimizations, 1);
        assert_eq!(stats.goals_reordered, 2);
    }

    #[test]
    fn test_join_optimizer() {
        let optimizer = JoinOptimizer::new();

        let optimized = optimizer.optimize_joins(vec![
            goal("item(?x)"),
            goal("price(?x, ?p) WHERE ?p > 100"),
            goal("in_stock(?x)"),
        ]);

        // The only goal with a bound variable must lead the join order
        assert!(optimized[0].pattern.contains("?p > 100"));
    }

    #[test]
    fn test_disable_reordering() {
        let mut optimizer = QueryOptimizer::new();
        optimizer.set_reordering(false);

        let goals = vec![goal("a(?x)"), goal("b(?x)"), goal("c(?x)")];
        let optimized = optimizer.optimize_goals(goals.clone());

        // With reordering off the input order must survive untouched
        for (idx, original) in goals.iter().enumerate() {
            assert_eq!(optimized[idx].pattern, original.pattern);
        }
    }

    #[test]
    fn test_stats_summary() {
        let stats = OptimizationStats {
            total_optimizations: 10,
            goals_reordered: 25,
            memoization_hits: 8,
            memoization_misses: 2,
            ..OptimizationStats::new()
        };

        let summary = stats.summary();
        // Last needle: 8 hits out of 10 lookups is an 80% hit rate
        for needle in ["10", "25", "8", "80"].iter() {
            assert!(summary.contains(*needle));
        }
    }

    #[test]
    fn test_memoization_hit_rate() {
        let mut stats = OptimizationStats::new();

        // Nothing recorded yet
        assert_eq!(stats.memoization_hit_rate(), 0.0);

        // 8 hits, 2 misses: 80%
        stats.memoization_hits = 8;
        stats.memoization_misses = 2;
        let deviation = (stats.memoization_hit_rate() - 0.8).abs();
        assert!(deviation < 0.01);

        // 10 hits, no misses: 100%
        stats.memoization_hits = 10;
        stats.memoization_misses = 0;
        assert_eq!(stats.memoization_hit_rate(), 1.0);
    }
}