1use super::ast::*;
11use std::collections::HashSet;
12
/// Outcome of a single `QueryOptimizer::optimize` run.
#[derive(Debug, Clone)]
pub struct OptimizationPlan {
    /// The query after every enabled pass that returned a rewrite has been applied.
    pub optimized_query: Query,
    /// The passes that returned a rewrite, in the order they were run.
    pub optimizations_applied: Vec<OptimizationType>,
    /// Value produced by the optimizer's cost model for `optimized_query`.
    pub estimated_cost: f64,
}
20
/// Identifies an optimization pass in `OptimizationPlan::optimizations_applied`.
///
/// The enum is a plain tag without payload, so it is `Copy` and `Hash` and can
/// be stored in sets or passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationType {
    /// Recorded by `QueryOptimizer::optimize` when the predicate-pushdown pass
    /// returns a rewritten query.
    PredicatePushdown,
    /// Recorded by `QueryOptimizer::optimize` when the join-reordering pass
    /// returns a rewritten query.
    JoinReordering,
    /// Recorded by `QueryOptimizer::optimize` when the constant-folding pass
    /// returns a rewritten query.
    ConstantFolding,
    /// Not recorded by any pass visible in this file.
    IndexHint,
    /// Not recorded by any pass visible in this file.
    EarlyFiltering,
    /// Not recorded by any pass visible in this file.
    PatternSimplification,
    /// Not recorded by any pass visible in this file.
    DeadCodeElimination,
}
31
/// Rule-based query optimizer whose individual passes can be switched on or off.
///
/// `Debug` and `Clone` are derived so the optimizer can be logged and embedded
/// in other derivable types; the struct only holds plain flags.
#[derive(Debug, Clone)]
pub struct QueryOptimizer {
    /// When `true`, `optimize` runs the predicate-pushdown pass.
    enable_predicate_pushdown: bool,
    /// When `true`, `optimize` runs the join-reordering pass.
    enable_join_reordering: bool,
    /// When `true`, `optimize` runs the constant-folding pass.
    enable_constant_folding: bool,
}
37
38impl QueryOptimizer {
39 pub fn new() -> Self {
40 Self {
41 enable_predicate_pushdown: true,
42 enable_join_reordering: true,
43 enable_constant_folding: true,
44 }
45 }
46
47 pub fn optimize(&self, query: Query) -> OptimizationPlan {
49 let mut optimized = query;
50 let mut optimizations = Vec::new();
51
52 if self.enable_constant_folding {
54 if let Some(q) = self.apply_constant_folding(optimized.clone()) {
55 optimized = q;
56 optimizations.push(OptimizationType::ConstantFolding);
57 }
58 }
59
60 if self.enable_predicate_pushdown {
61 if let Some(q) = self.apply_predicate_pushdown(optimized.clone()) {
62 optimized = q;
63 optimizations.push(OptimizationType::PredicatePushdown);
64 }
65 }
66
67 if self.enable_join_reordering {
68 if let Some(q) = self.apply_join_reordering(optimized.clone()) {
69 optimized = q;
70 optimizations.push(OptimizationType::JoinReordering);
71 }
72 }
73
74 let cost = self.estimate_cost(&optimized);
75
76 OptimizationPlan {
77 optimized_query: optimized,
78 optimizations_applied: optimizations,
79 estimated_cost: cost,
80 }
81 }
82
83 fn apply_constant_folding(&self, mut query: Query) -> Option<Query> {
85 let mut changed = false;
86
87 for statement in &mut query.statements {
88 if self.fold_statement(statement) {
89 changed = true;
90 }
91 }
92
93 if changed {
94 Some(query)
95 } else {
96 None
97 }
98 }
99
100 fn fold_statement(&self, statement: &mut Statement) -> bool {
101 match statement {
102 Statement::Match(clause) => {
103 let mut changed = false;
104 if let Some(where_clause) = &mut clause.where_clause {
105 if let Some(folded) = self.fold_expression(&where_clause.condition) {
106 where_clause.condition = folded;
107 changed = true;
108 }
109 }
110 changed
111 }
112 Statement::Return(clause) => {
113 let mut changed = false;
114 for item in &mut clause.items {
115 if let Some(folded) = self.fold_expression(&item.expression) {
116 item.expression = folded;
117 changed = true;
118 }
119 }
120 changed
121 }
122 _ => false,
123 }
124 }
125
126 fn fold_expression(&self, expr: &Expression) -> Option<Expression> {
127 match expr {
128 Expression::BinaryOp { left, op, right } => {
129 let left = self
131 .fold_expression(left)
132 .unwrap_or_else(|| (**left).clone());
133 let right = self
134 .fold_expression(right)
135 .unwrap_or_else(|| (**right).clone());
136
137 if left.is_constant() && right.is_constant() {
139 return self.evaluate_constant_binary_op(&left, *op, &right);
140 }
141
142 Some(Expression::BinaryOp {
144 left: Box::new(left),
145 op: *op,
146 right: Box::new(right),
147 })
148 }
149 Expression::UnaryOp { op, operand } => {
150 let operand = self
151 .fold_expression(operand)
152 .unwrap_or_else(|| (**operand).clone());
153
154 if operand.is_constant() {
155 return self.evaluate_constant_unary_op(*op, &operand);
156 }
157
158 Some(Expression::UnaryOp {
159 op: *op,
160 operand: Box::new(operand),
161 })
162 }
163 _ => None,
164 }
165 }
166
167 fn evaluate_constant_binary_op(
168 &self,
169 left: &Expression,
170 op: BinaryOperator,
171 right: &Expression,
172 ) -> Option<Expression> {
173 match (left, op, right) {
174 (Expression::Integer(a), BinaryOperator::Add, Expression::Integer(b)) => {
176 Some(Expression::Integer(a + b))
177 }
178 (Expression::Integer(a), BinaryOperator::Subtract, Expression::Integer(b)) => {
179 Some(Expression::Integer(a - b))
180 }
181 (Expression::Integer(a), BinaryOperator::Multiply, Expression::Integer(b)) => {
182 Some(Expression::Integer(a * b))
183 }
184 (Expression::Integer(a), BinaryOperator::Divide, Expression::Integer(b)) if *b != 0 => {
185 Some(Expression::Integer(a / b))
186 }
187 (Expression::Integer(a), BinaryOperator::Modulo, Expression::Integer(b)) if *b != 0 => {
188 Some(Expression::Integer(a % b))
189 }
190 (Expression::Float(a), BinaryOperator::Add, Expression::Float(b)) => {
191 Some(Expression::Float(a + b))
192 }
193 (Expression::Float(a), BinaryOperator::Subtract, Expression::Float(b)) => {
194 Some(Expression::Float(a - b))
195 }
196 (Expression::Float(a), BinaryOperator::Multiply, Expression::Float(b)) => {
197 Some(Expression::Float(a * b))
198 }
199 (Expression::Float(a), BinaryOperator::Divide, Expression::Float(b)) if *b != 0.0 => {
200 Some(Expression::Float(a / b))
201 }
202 (Expression::Integer(a), BinaryOperator::Equal, Expression::Integer(b)) => {
204 Some(Expression::Boolean(a == b))
205 }
206 (Expression::Integer(a), BinaryOperator::NotEqual, Expression::Integer(b)) => {
207 Some(Expression::Boolean(a != b))
208 }
209 (Expression::Integer(a), BinaryOperator::LessThan, Expression::Integer(b)) => {
210 Some(Expression::Boolean(a < b))
211 }
212 (Expression::Integer(a), BinaryOperator::LessThanOrEqual, Expression::Integer(b)) => {
213 Some(Expression::Boolean(a <= b))
214 }
215 (Expression::Integer(a), BinaryOperator::GreaterThan, Expression::Integer(b)) => {
216 Some(Expression::Boolean(a > b))
217 }
218 (
219 Expression::Integer(a),
220 BinaryOperator::GreaterThanOrEqual,
221 Expression::Integer(b),
222 ) => Some(Expression::Boolean(a >= b)),
223 (Expression::Float(a), BinaryOperator::Equal, Expression::Float(b)) => {
225 Some(Expression::Boolean((a - b).abs() < f64::EPSILON))
226 }
227 (Expression::Float(a), BinaryOperator::NotEqual, Expression::Float(b)) => {
228 Some(Expression::Boolean((a - b).abs() >= f64::EPSILON))
229 }
230 (Expression::Float(a), BinaryOperator::LessThan, Expression::Float(b)) => {
231 Some(Expression::Boolean(a < b))
232 }
233 (Expression::Float(a), BinaryOperator::LessThanOrEqual, Expression::Float(b)) => {
234 Some(Expression::Boolean(a <= b))
235 }
236 (Expression::Float(a), BinaryOperator::GreaterThan, Expression::Float(b)) => {
237 Some(Expression::Boolean(a > b))
238 }
239 (Expression::Float(a), BinaryOperator::GreaterThanOrEqual, Expression::Float(b)) => {
240 Some(Expression::Boolean(a >= b))
241 }
242 (Expression::String(a), BinaryOperator::Equal, Expression::String(b)) => {
244 Some(Expression::Boolean(a == b))
245 }
246 (Expression::String(a), BinaryOperator::NotEqual, Expression::String(b)) => {
247 Some(Expression::Boolean(a != b))
248 }
249 (Expression::Boolean(a), BinaryOperator::And, Expression::Boolean(b)) => {
251 Some(Expression::Boolean(*a && *b))
252 }
253 (Expression::Boolean(a), BinaryOperator::Or, Expression::Boolean(b)) => {
254 Some(Expression::Boolean(*a || *b))
255 }
256 (Expression::Boolean(a), BinaryOperator::Equal, Expression::Boolean(b)) => {
257 Some(Expression::Boolean(a == b))
258 }
259 (Expression::Boolean(a), BinaryOperator::NotEqual, Expression::Boolean(b)) => {
260 Some(Expression::Boolean(a != b))
261 }
262 _ => None,
263 }
264 }
265
266 fn evaluate_constant_unary_op(
267 &self,
268 op: UnaryOperator,
269 operand: &Expression,
270 ) -> Option<Expression> {
271 match (op, operand) {
272 (UnaryOperator::Not, Expression::Boolean(b)) => Some(Expression::Boolean(!b)),
273 (UnaryOperator::Minus, Expression::Integer(n)) => Some(Expression::Integer(-n)),
274 (UnaryOperator::Minus, Expression::Float(n)) => Some(Expression::Float(-n)),
275 _ => None,
276 }
277 }
278
279 fn apply_predicate_pushdown(&self, query: Query) -> Option<Query> {
282 None
288 }
289
290 fn apply_join_reordering(&self, query: Query) -> Option<Query> {
292 let mut optimized = query.clone();
296 let mut changed = false;
297
298 for statement in &mut optimized.statements {
299 if let Statement::Match(clause) = statement {
300 let mut patterns = clause.patterns.clone();
301
302 patterns.sort_by_key(|p| {
304 let selectivity = self.estimate_pattern_selectivity(p);
305 -(selectivity * 1000.0) as i64
307 });
308
309 if patterns != clause.patterns {
310 clause.patterns = patterns;
311 changed = true;
312 }
313 }
314 }
315
316 if changed {
317 Some(optimized)
318 } else {
319 None
320 }
321 }
322
323 fn estimate_pattern_selectivity(&self, pattern: &Pattern) -> f64 {
325 match pattern {
326 Pattern::Node(node) => {
327 let mut selectivity = 0.3; selectivity += node.labels.len() as f64 * 0.1;
331
332 if let Some(props) = &node.properties {
334 selectivity += props.len() as f64 * 0.15;
335 }
336
337 selectivity.min(1.0)
338 }
339 Pattern::Relationship(rel) => {
340 let mut selectivity = 0.2; if rel.rel_type.is_some() {
344 selectivity += 0.2;
345 }
346
347 if let Some(props) = &rel.properties {
349 selectivity += props.len() as f64 * 0.15;
350 }
351
352 selectivity +=
354 self.estimate_pattern_selectivity(&Pattern::Node(*rel.from.clone())) * 0.3;
355 selectivity += self.estimate_pattern_selectivity(&*rel.to) * 0.3;
357
358 selectivity.min(1.0)
359 }
360 Pattern::Hyperedge(hyperedge) => {
361 let mut selectivity = 0.5; selectivity += hyperedge.arity as f64 * 0.1;
365
366 if let Some(props) = &hyperedge.properties {
367 selectivity += props.len() as f64 * 0.15;
368 }
369
370 selectivity.min(1.0)
371 }
372 Pattern::Path(_) => 0.1, }
374 }
375
376 fn estimate_cost(&self, query: &Query) -> f64 {
378 let mut cost = 0.0;
379
380 for statement in &query.statements {
381 cost += self.estimate_statement_cost(statement);
382 }
383
384 cost
385 }
386
387 fn estimate_statement_cost(&self, statement: &Statement) -> f64 {
388 match statement {
389 Statement::Match(clause) => {
390 let mut cost = 0.0;
391
392 for pattern in &clause.patterns {
393 cost += self.estimate_pattern_cost(pattern);
394 }
395
396 if clause.where_clause.is_some() {
398 cost *= 1.2;
399 }
400
401 cost
402 }
403 Statement::Create(clause) => {
404 clause.patterns.len() as f64 * 50.0
406 }
407 Statement::Merge(clause) => {
408 self.estimate_pattern_cost(&clause.pattern) * 2.0
410 }
411 Statement::Delete(_) => 30.0,
412 Statement::Set(_) => 20.0,
413 Statement::Remove(clause) => clause.items.len() as f64 * 15.0,
414 Statement::Return(clause) => {
415 let mut cost = 10.0;
416
417 for item in &clause.items {
419 if item.expression.has_aggregation() {
420 cost += 50.0;
421 }
422 }
423
424 if clause.order_by.is_some() {
426 cost += 100.0;
427 }
428
429 cost
430 }
431 Statement::With(_) => 15.0,
432 }
433 }
434
435 fn estimate_pattern_cost(&self, pattern: &Pattern) -> f64 {
436 match pattern {
437 Pattern::Node(node) => {
438 let mut cost = 100.0;
439
440 cost /= (1.0 + node.labels.len() as f64 * 0.5);
442
443 if let Some(props) = &node.properties {
445 cost /= (1.0 + props.len() as f64 * 0.3);
446 }
447
448 cost
449 }
450 Pattern::Relationship(rel) => {
451 let mut cost = 200.0; if rel.rel_type.is_some() {
455 cost *= 0.7;
456 }
457
458 if let Some(range) = &rel.range {
460 let max = range.max.unwrap_or(10);
461 cost *= max as f64;
462 }
463
464 cost
465 }
466 Pattern::Hyperedge(hyperedge) => {
467 150.0 * hyperedge.arity as f64
469 }
470 Pattern::Path(_) => 300.0, }
472 }
473
474 fn get_variables_in_expression(&self, expr: &Expression) -> HashSet<String> {
476 let mut vars = HashSet::new();
477 self.collect_variables(expr, &mut vars);
478 vars
479 }
480
481 fn collect_variables(&self, expr: &Expression, vars: &mut HashSet<String>) {
482 match expr {
483 Expression::Variable(name) => {
484 vars.insert(name.clone());
485 }
486 Expression::Property { object, .. } => {
487 self.collect_variables(object, vars);
488 }
489 Expression::BinaryOp { left, right, .. } => {
490 self.collect_variables(left, vars);
491 self.collect_variables(right, vars);
492 }
493 Expression::UnaryOp { operand, .. } => {
494 self.collect_variables(operand, vars);
495 }
496 Expression::FunctionCall { args, .. } => {
497 for arg in args {
498 self.collect_variables(arg, vars);
499 }
500 }
501 Expression::Aggregation { expression, .. } => {
502 self.collect_variables(expression, vars);
503 }
504 Expression::List(items) => {
505 for item in items {
506 self.collect_variables(item, vars);
507 }
508 }
509 Expression::Case {
510 expression,
511 alternatives,
512 default,
513 } => {
514 if let Some(expr) = expression {
515 self.collect_variables(expr, vars);
516 }
517 for (cond, result) in alternatives {
518 self.collect_variables(cond, vars);
519 self.collect_variables(result, vars);
520 }
521 if let Some(default_expr) = default {
522 self.collect_variables(default_expr, vars);
523 }
524 }
525 _ => {}
526 }
527 }
528}
529
530impl Default for QueryOptimizer {
531 fn default() -> Self {
532 Self::new()
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// Builds a property-less node pattern bound to `n` with the given labels.
    fn node_with_labels(labels: &[&str]) -> Pattern {
        Pattern::Node(NodePattern {
            variable: Some("n".to_string()),
            labels: labels.iter().map(|label| label.to_string()).collect(),
            properties: None,
        })
    }

    /// A `WHERE` made only of literals must be reported as constant-folded.
    #[test]
    fn test_constant_folding() {
        let parsed = parse_cypher("MATCH (n) WHERE 2 + 3 = 5 RETURN n").unwrap();
        let plan = QueryOptimizer::new().optimize(parsed);

        let applied = &plan.optimizations_applied;
        assert!(applied.contains(&OptimizationType::ConstantFolding));
    }

    /// Any non-trivial query must have a strictly positive estimated cost.
    #[test]
    fn test_cost_estimation() {
        let parsed = parse_cypher("MATCH (n:Person {age: 30}) RETURN n").unwrap();
        let estimated = QueryOptimizer::new().estimate_cost(&parsed);

        assert!(estimated > 0.0);
    }

    /// Adding a label must raise a node pattern's selectivity score.
    #[test]
    fn test_pattern_selectivity() {
        let optimizer = QueryOptimizer::new();

        let labelled = node_with_labels(&["Person"]);
        let unlabelled = node_with_labels(&[]);

        let labelled_score = optimizer.estimate_pattern_selectivity(&labelled);
        let unlabelled_score = optimizer.estimate_pattern_selectivity(&unlabelled);

        assert!(labelled_score > unlabelled_score);
    }
}