ruvector_graph/cypher/optimizer.rs

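//! Query optimizer for Cypher queries
//!
//! Rewrites a parsed query and attaches a heuristic cost estimate. Passes that
//! currently rewrite the query:
//! - Constant folding (evaluate literal-only arithmetic, comparison and boolean
//!   sub-expressions before execution)
//! - Join reordering (evaluate the most selective MATCH patterns first to keep
//!   intermediate results small)
//!
//! Predicate pushdown is stubbed out and never modifies the query yet. Index hints,
//! early filtering, pattern simplification and dead code elimination are named in
//! `OptimizationType` but not implemented.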
9
10use super::ast::*;
11use std::collections::HashSet;
12
13/// Query optimization plan
14#[derive(Debug, Clone)]
15pub struct OptimizationPlan {
16    pub optimized_query: Query,
17    pub optimizations_applied: Vec<OptimizationType>,
18    pub estimated_cost: f64,
19}
20
/// Identifies an optimization pass recorded in `OptimizationPlan::optimizations_applied`.
///
/// `QueryOptimizer::optimize` only ever records `ConstantFolding`,
/// `PredicatePushdown` and `JoinReordering`; the remaining variants name passes
/// that are not implemented yet.
///
/// The enum is fieldless, so it is `Copy` and `Hash` and can be passed by value
/// or collected into a `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationType {
    /// Filters were moved closer to the data access they constrain.
    PredicatePushdown,
    /// MATCH patterns were reordered so that more selective ones run first.
    JoinReordering,
    /// Literal-only sub-expressions were evaluated ahead of execution.
    ConstantFolding,
    /// Reserved: index usage hints (not produced yet).
    IndexHint,
    /// Reserved: early filtering (not produced yet).
    EarlyFiltering,
    /// Reserved: pattern simplification (not produced yet).
    PatternSimplification,
    /// Reserved: dead code elimination (not produced yet).
    DeadCodeElimination,
}
31
/// Heuristic rewriter for parsed Cypher queries.
///
/// Each flag toggles one pass of `QueryOptimizer::optimize`; `QueryOptimizer::new`
/// (and therefore `Default`) turns all of them on.
#[derive(Debug, Clone)]
pub struct QueryOptimizer {
    /// Run the predicate-pushdown pass.
    enable_predicate_pushdown: bool,
    /// Run the join (MATCH pattern) reordering pass.
    enable_join_reordering: bool,
    /// Run the constant-folding pass.
    enable_constant_folding: bool,
}
37
38impl QueryOptimizer {
39    pub fn new() -> Self {
40        Self {
41            enable_predicate_pushdown: true,
42            enable_join_reordering: true,
43            enable_constant_folding: true,
44        }
45    }
46
47    /// Optimize a query and return an execution plan
48    pub fn optimize(&self, query: Query) -> OptimizationPlan {
49        let mut optimized = query;
50        let mut optimizations = Vec::new();
51
52        // Apply optimizations in order
53        if self.enable_constant_folding {
54            if let Some(q) = self.apply_constant_folding(optimized.clone()) {
55                optimized = q;
56                optimizations.push(OptimizationType::ConstantFolding);
57            }
58        }
59
60        if self.enable_predicate_pushdown {
61            if let Some(q) = self.apply_predicate_pushdown(optimized.clone()) {
62                optimized = q;
63                optimizations.push(OptimizationType::PredicatePushdown);
64            }
65        }
66
67        if self.enable_join_reordering {
68            if let Some(q) = self.apply_join_reordering(optimized.clone()) {
69                optimized = q;
70                optimizations.push(OptimizationType::JoinReordering);
71            }
72        }
73
74        let cost = self.estimate_cost(&optimized);
75
76        OptimizationPlan {
77            optimized_query: optimized,
78            optimizations_applied: optimizations,
79            estimated_cost: cost,
80        }
81    }
82
83    /// Apply constant folding to simplify expressions
84    fn apply_constant_folding(&self, mut query: Query) -> Option<Query> {
85        let mut changed = false;
86
87        for statement in &mut query.statements {
88            if self.fold_statement(statement) {
89                changed = true;
90            }
91        }
92
93        if changed {
94            Some(query)
95        } else {
96            None
97        }
98    }
99
100    fn fold_statement(&self, statement: &mut Statement) -> bool {
101        match statement {
102            Statement::Match(clause) => {
103                let mut changed = false;
104                if let Some(where_clause) = &mut clause.where_clause {
105                    if let Some(folded) = self.fold_expression(&where_clause.condition) {
106                        where_clause.condition = folded;
107                        changed = true;
108                    }
109                }
110                changed
111            }
112            Statement::Return(clause) => {
113                let mut changed = false;
114                for item in &mut clause.items {
115                    if let Some(folded) = self.fold_expression(&item.expression) {
116                        item.expression = folded;
117                        changed = true;
118                    }
119                }
120                changed
121            }
122            _ => false,
123        }
124    }
125
126    fn fold_expression(&self, expr: &Expression) -> Option<Expression> {
127        match expr {
128            Expression::BinaryOp { left, op, right } => {
129                // Fold operands first
130                let left = self
131                    .fold_expression(left)
132                    .unwrap_or_else(|| (**left).clone());
133                let right = self
134                    .fold_expression(right)
135                    .unwrap_or_else(|| (**right).clone());
136
137                // Try to evaluate constant expressions
138                if left.is_constant() && right.is_constant() {
139                    return self.evaluate_constant_binary_op(&left, *op, &right);
140                }
141
142                // Return simplified expression
143                Some(Expression::BinaryOp {
144                    left: Box::new(left),
145                    op: *op,
146                    right: Box::new(right),
147                })
148            }
149            Expression::UnaryOp { op, operand } => {
150                let operand = self
151                    .fold_expression(operand)
152                    .unwrap_or_else(|| (**operand).clone());
153
154                if operand.is_constant() {
155                    return self.evaluate_constant_unary_op(*op, &operand);
156                }
157
158                Some(Expression::UnaryOp {
159                    op: *op,
160                    operand: Box::new(operand),
161                })
162            }
163            _ => None,
164        }
165    }
166
167    fn evaluate_constant_binary_op(
168        &self,
169        left: &Expression,
170        op: BinaryOperator,
171        right: &Expression,
172    ) -> Option<Expression> {
173        match (left, op, right) {
174            // Arithmetic operations
175            (Expression::Integer(a), BinaryOperator::Add, Expression::Integer(b)) => {
176                Some(Expression::Integer(a + b))
177            }
178            (Expression::Integer(a), BinaryOperator::Subtract, Expression::Integer(b)) => {
179                Some(Expression::Integer(a - b))
180            }
181            (Expression::Integer(a), BinaryOperator::Multiply, Expression::Integer(b)) => {
182                Some(Expression::Integer(a * b))
183            }
184            (Expression::Integer(a), BinaryOperator::Divide, Expression::Integer(b)) if *b != 0 => {
185                Some(Expression::Integer(a / b))
186            }
187            (Expression::Integer(a), BinaryOperator::Modulo, Expression::Integer(b)) if *b != 0 => {
188                Some(Expression::Integer(a % b))
189            }
190            (Expression::Float(a), BinaryOperator::Add, Expression::Float(b)) => {
191                Some(Expression::Float(a + b))
192            }
193            (Expression::Float(a), BinaryOperator::Subtract, Expression::Float(b)) => {
194                Some(Expression::Float(a - b))
195            }
196            (Expression::Float(a), BinaryOperator::Multiply, Expression::Float(b)) => {
197                Some(Expression::Float(a * b))
198            }
199            (Expression::Float(a), BinaryOperator::Divide, Expression::Float(b)) if *b != 0.0 => {
200                Some(Expression::Float(a / b))
201            }
202            // Comparison operations for integers
203            (Expression::Integer(a), BinaryOperator::Equal, Expression::Integer(b)) => {
204                Some(Expression::Boolean(a == b))
205            }
206            (Expression::Integer(a), BinaryOperator::NotEqual, Expression::Integer(b)) => {
207                Some(Expression::Boolean(a != b))
208            }
209            (Expression::Integer(a), BinaryOperator::LessThan, Expression::Integer(b)) => {
210                Some(Expression::Boolean(a < b))
211            }
212            (Expression::Integer(a), BinaryOperator::LessThanOrEqual, Expression::Integer(b)) => {
213                Some(Expression::Boolean(a <= b))
214            }
215            (Expression::Integer(a), BinaryOperator::GreaterThan, Expression::Integer(b)) => {
216                Some(Expression::Boolean(a > b))
217            }
218            (
219                Expression::Integer(a),
220                BinaryOperator::GreaterThanOrEqual,
221                Expression::Integer(b),
222            ) => Some(Expression::Boolean(a >= b)),
223            // Comparison operations for floats
224            (Expression::Float(a), BinaryOperator::Equal, Expression::Float(b)) => {
225                Some(Expression::Boolean((a - b).abs() < f64::EPSILON))
226            }
227            (Expression::Float(a), BinaryOperator::NotEqual, Expression::Float(b)) => {
228                Some(Expression::Boolean((a - b).abs() >= f64::EPSILON))
229            }
230            (Expression::Float(a), BinaryOperator::LessThan, Expression::Float(b)) => {
231                Some(Expression::Boolean(a < b))
232            }
233            (Expression::Float(a), BinaryOperator::LessThanOrEqual, Expression::Float(b)) => {
234                Some(Expression::Boolean(a <= b))
235            }
236            (Expression::Float(a), BinaryOperator::GreaterThan, Expression::Float(b)) => {
237                Some(Expression::Boolean(a > b))
238            }
239            (Expression::Float(a), BinaryOperator::GreaterThanOrEqual, Expression::Float(b)) => {
240                Some(Expression::Boolean(a >= b))
241            }
242            // String comparison
243            (Expression::String(a), BinaryOperator::Equal, Expression::String(b)) => {
244                Some(Expression::Boolean(a == b))
245            }
246            (Expression::String(a), BinaryOperator::NotEqual, Expression::String(b)) => {
247                Some(Expression::Boolean(a != b))
248            }
249            // Boolean operations
250            (Expression::Boolean(a), BinaryOperator::And, Expression::Boolean(b)) => {
251                Some(Expression::Boolean(*a && *b))
252            }
253            (Expression::Boolean(a), BinaryOperator::Or, Expression::Boolean(b)) => {
254                Some(Expression::Boolean(*a || *b))
255            }
256            (Expression::Boolean(a), BinaryOperator::Equal, Expression::Boolean(b)) => {
257                Some(Expression::Boolean(a == b))
258            }
259            (Expression::Boolean(a), BinaryOperator::NotEqual, Expression::Boolean(b)) => {
260                Some(Expression::Boolean(a != b))
261            }
262            _ => None,
263        }
264    }
265
266    fn evaluate_constant_unary_op(
267        &self,
268        op: UnaryOperator,
269        operand: &Expression,
270    ) -> Option<Expression> {
271        match (op, operand) {
272            (UnaryOperator::Not, Expression::Boolean(b)) => Some(Expression::Boolean(!b)),
273            (UnaryOperator::Minus, Expression::Integer(n)) => Some(Expression::Integer(-n)),
274            (UnaryOperator::Minus, Expression::Float(n)) => Some(Expression::Float(-n)),
275            _ => None,
276        }
277    }
278
279    /// Apply predicate pushdown optimization
280    /// Move WHERE clauses as close to data access as possible
281    fn apply_predicate_pushdown(&self, query: Query) -> Option<Query> {
282        // In a real implementation, this would analyze the query graph
283        // and push predicates down to the earliest possible point
284        // For now, we'll do a simple transformation
285
286        // This is a placeholder - real implementation would be more complex
287        None
288    }
289
290    /// Reorder joins to minimize intermediate result sizes
291    fn apply_join_reordering(&self, query: Query) -> Option<Query> {
292        // Analyze pattern complexity and reorder based on selectivity
293        // Patterns with more constraints should be evaluated first
294
295        let mut optimized = query.clone();
296        let mut changed = false;
297
298        for statement in &mut optimized.statements {
299            if let Statement::Match(clause) = statement {
300                let mut patterns = clause.patterns.clone();
301
302                // Sort patterns by estimated selectivity (more selective first)
303                patterns.sort_by_key(|p| {
304                    let selectivity = self.estimate_pattern_selectivity(p);
305                    // Use negative to sort in descending order (most selective first)
306                    -(selectivity * 1000.0) as i64
307                });
308
309                if patterns != clause.patterns {
310                    clause.patterns = patterns;
311                    changed = true;
312                }
313            }
314        }
315
316        if changed {
317            Some(optimized)
318        } else {
319            None
320        }
321    }
322
323    /// Estimate the selectivity of a pattern (0.0 = least selective, 1.0 = most selective)
324    fn estimate_pattern_selectivity(&self, pattern: &Pattern) -> f64 {
325        match pattern {
326            Pattern::Node(node) => {
327                let mut selectivity = 0.3; // Base selectivity for node
328
329                // More labels = more selective
330                selectivity += node.labels.len() as f64 * 0.1;
331
332                // Properties = more selective
333                if let Some(props) = &node.properties {
334                    selectivity += props.len() as f64 * 0.15;
335                }
336
337                selectivity.min(1.0)
338            }
339            Pattern::Relationship(rel) => {
340                let mut selectivity = 0.2; // Base selectivity for relationship
341
342                // Specific type = more selective
343                if rel.rel_type.is_some() {
344                    selectivity += 0.2;
345                }
346
347                // Properties = more selective
348                if let Some(props) = &rel.properties {
349                    selectivity += props.len() as f64 * 0.15;
350                }
351
352                // Add selectivity from connected nodes
353                selectivity +=
354                    self.estimate_pattern_selectivity(&Pattern::Node(*rel.from.clone())) * 0.3;
355                // rel.to is now a Pattern (can be NodePattern or chained RelationshipPattern)
356                selectivity += self.estimate_pattern_selectivity(&*rel.to) * 0.3;
357
358                selectivity.min(1.0)
359            }
360            Pattern::Hyperedge(hyperedge) => {
361                let mut selectivity = 0.5; // Hyperedges are typically more selective
362
363                // More nodes involved = more selective
364                selectivity += hyperedge.arity as f64 * 0.1;
365
366                if let Some(props) = &hyperedge.properties {
367                    selectivity += props.len() as f64 * 0.15;
368                }
369
370                selectivity.min(1.0)
371            }
372            Pattern::Path(_) => 0.1, // Paths are typically less selective
373        }
374    }
375
376    /// Estimate the cost of executing a query
377    fn estimate_cost(&self, query: &Query) -> f64 {
378        let mut cost = 0.0;
379
380        for statement in &query.statements {
381            cost += self.estimate_statement_cost(statement);
382        }
383
384        cost
385    }
386
387    fn estimate_statement_cost(&self, statement: &Statement) -> f64 {
388        match statement {
389            Statement::Match(clause) => {
390                let mut cost = 0.0;
391
392                for pattern in &clause.patterns {
393                    cost += self.estimate_pattern_cost(pattern);
394                }
395
396                // WHERE clause adds filtering cost
397                if clause.where_clause.is_some() {
398                    cost *= 1.2;
399                }
400
401                cost
402            }
403            Statement::Create(clause) => {
404                // Create operations are expensive
405                clause.patterns.len() as f64 * 50.0
406            }
407            Statement::Merge(clause) => {
408                // Merge is more expensive than match or create alone
409                self.estimate_pattern_cost(&clause.pattern) * 2.0
410            }
411            Statement::Delete(_) => 30.0,
412            Statement::Set(_) => 20.0,
413            Statement::Remove(clause) => clause.items.len() as f64 * 15.0,
414            Statement::Return(clause) => {
415                let mut cost = 10.0;
416
417                // Aggregations are expensive
418                for item in &clause.items {
419                    if item.expression.has_aggregation() {
420                        cost += 50.0;
421                    }
422                }
423
424                // Sorting adds cost
425                if clause.order_by.is_some() {
426                    cost += 100.0;
427                }
428
429                cost
430            }
431            Statement::With(_) => 15.0,
432        }
433    }
434
435    fn estimate_pattern_cost(&self, pattern: &Pattern) -> f64 {
436        match pattern {
437            Pattern::Node(node) => {
438                let mut cost = 100.0;
439
440                // Labels reduce cost (more selective)
441                cost /= (1.0 + node.labels.len() as f64 * 0.5);
442
443                // Properties reduce cost
444                if let Some(props) = &node.properties {
445                    cost /= (1.0 + props.len() as f64 * 0.3);
446                }
447
448                cost
449            }
450            Pattern::Relationship(rel) => {
451                let mut cost = 200.0; // Relationships are more expensive
452
453                // Specific type reduces cost
454                if rel.rel_type.is_some() {
455                    cost *= 0.7;
456                }
457
458                // Variable length paths are very expensive
459                if let Some(range) = &rel.range {
460                    let max = range.max.unwrap_or(10);
461                    cost *= max as f64;
462                }
463
464                cost
465            }
466            Pattern::Hyperedge(hyperedge) => {
467                // Hyperedges are more expensive due to N-ary nature
468                150.0 * hyperedge.arity as f64
469            }
470            Pattern::Path(_) => 300.0, // Paths can be expensive
471        }
472    }
473
474    /// Get variables used in an expression
475    fn get_variables_in_expression(&self, expr: &Expression) -> HashSet<String> {
476        let mut vars = HashSet::new();
477        self.collect_variables(expr, &mut vars);
478        vars
479    }
480
481    fn collect_variables(&self, expr: &Expression, vars: &mut HashSet<String>) {
482        match expr {
483            Expression::Variable(name) => {
484                vars.insert(name.clone());
485            }
486            Expression::Property { object, .. } => {
487                self.collect_variables(object, vars);
488            }
489            Expression::BinaryOp { left, right, .. } => {
490                self.collect_variables(left, vars);
491                self.collect_variables(right, vars);
492            }
493            Expression::UnaryOp { operand, .. } => {
494                self.collect_variables(operand, vars);
495            }
496            Expression::FunctionCall { args, .. } => {
497                for arg in args {
498                    self.collect_variables(arg, vars);
499                }
500            }
501            Expression::Aggregation { expression, .. } => {
502                self.collect_variables(expression, vars);
503            }
504            Expression::List(items) => {
505                for item in items {
506                    self.collect_variables(item, vars);
507                }
508            }
509            Expression::Case {
510                expression,
511                alternatives,
512                default,
513            } => {
514                if let Some(expr) = expression {
515                    self.collect_variables(expr, vars);
516                }
517                for (cond, result) in alternatives {
518                    self.collect_variables(cond, vars);
519                    self.collect_variables(result, vars);
520                }
521                if let Some(default_expr) = default {
522                    self.collect_variables(default_expr, vars);
523                }
524            }
525            _ => {}
526        }
527    }
528}
529
530impl Default for QueryOptimizer {
531    fn default() -> Self {
532        Self::new()
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// Builds a node pattern bound to variable `n` with the given labels and no
    /// inline properties.
    fn node_named_n(labels: Vec<String>) -> Pattern {
        Pattern::Node(NodePattern {
            variable: Some("n".to_string()),
            labels,
            properties: None,
        })
    }

    #[test]
    fn test_constant_folding() {
        let parsed = parse_cypher("MATCH (n) WHERE 2 + 3 = 5 RETURN n").unwrap();
        let plan = QueryOptimizer::new().optimize(parsed);

        let folded = plan
            .optimizations_applied
            .iter()
            .any(|applied| *applied == OptimizationType::ConstantFolding);
        assert!(folded);
    }

    #[test]
    fn test_cost_estimation() {
        let parsed = parse_cypher("MATCH (n:Person {age: 30}) RETURN n").unwrap();
        let estimated = QueryOptimizer::new().estimate_cost(&parsed);

        assert!(estimated > 0.0);
    }

    #[test]
    fn test_pattern_selectivity() {
        let optimizer = QueryOptimizer::new();
        let labelled = node_named_n(vec!["Person".to_string()]);
        let unlabelled = node_named_n(Vec::new());

        assert!(
            optimizer.estimate_pattern_selectivity(&labelled)
                > optimizer.estimate_pattern_selectivity(&unlabelled)
        );
    }
}