Skip to main content

xlog_logic/
optimizer.rs

1//! Query optimizer for join ordering and predicate pushdown.
2//!
3//! This module provides cost-based query optimization for XLOG's relational IR.
4//! It uses GPU-resident statistics from [`xlog_stats::StatsManager`] to make
5//! informed decisions about:
6//!
7//! - **Predicate pushdown**: Moving filter predicates closer to base scans to
8//!   reduce intermediate result sizes early in the pipeline.
9//! - **Cost estimation**: Computing expected row counts, CPU costs, GPU memory
10//!   usage, and data transfer counts for plan nodes.
11//! - **Join ordering**: (Future) Reordering joins based on selectivity estimates
12//!   to minimize intermediate result sizes.
13//!
14//! # Usage
15//!
16//! ```ignore
17//! use std::sync::Arc;
18//! use xlog_logic::optimizer::{Optimizer, OptimizerConfig, PlanCost};
19//! use xlog_stats::StatsManager;
20//!
21//! let stats = Arc::new(StatsManager::new());
22//! let optimizer = Optimizer::new(stats);
23//!
24//! // Optimize a query plan
25//! let optimized_plan = optimizer.optimize(original_plan);
26//!
27//! // Get cost estimates
28//! let cost = optimizer.estimate_cost(&optimized_plan);
29//! println!("Estimated rows: {}, GPU memory: {} bytes", cost.rows, cost.gpu_mem);
30//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34use xlog_core::{RelId, Schema};
35use xlog_ir::{CompareOp, Expr, JoinType, RirNode};
36use xlog_stats::StatsManager;
37
/// Tunable settings for the query optimizer.
///
/// Groups the thresholds used for algorithm selection, the fallback
/// estimates used when statistics are missing, and the feature toggles.
/// Start from [`OptimizerConfig::default`] and override individual fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OptimizerConfig {
    /// Largest relation count for which exhaustive dynamic programming is used.
    ///
    /// Queries touching more relations than this fall back to a greedy
    /// join-ordering algorithm so planning time does not grow exponentially.
    /// Default: 10 relations.
    pub dp_threshold: usize,

    /// Access-heat level above which index creation is recommended.
    ///
    /// A relation whose heat exceeds this value is a candidate for an index
    /// that would speed up later queries. Default: 0.7.
    pub index_heat_threshold: f32,

    /// Whether predicate pushdown runs at all.
    ///
    /// When `true`, filters are moved down through projections and joins so
    /// they apply as early as possible. Default: true.
    pub enable_pushdown: bool,

    /// Selectivity assumed for a filter when statistics give no estimate.
    ///
    /// Fallback used only when column statistics cannot do better.
    /// Default: 0.1 (10% of rows survive).
    pub default_filter_selectivity: f64,

    /// Relative cost weight of a GPU-to-host data transfer.
    ///
    /// Transfers are expensive compared with local GPU work; this factor
    /// expresses how much more expensive. Default: 100.0.
    pub transfer_cost_multiplier: f64,

    /// Row size, in bytes, assumed for GPU memory estimates without a schema.
    ///
    /// Default: 32 bytes (4 columns averaging 8 bytes each).
    pub default_bytes_per_row: u64,
}

impl Default for OptimizerConfig {
    /// Builds the stock configuration; each value matches the default
    /// documented on the corresponding field.
    fn default() -> Self {
        OptimizerConfig {
            // Feature toggles.
            enable_pushdown: true,
            // Algorithm-selection thresholds.
            dp_threshold: 10,
            index_heat_threshold: 0.7,
            // Fallback estimates and cost weights.
            default_filter_selectivity: 0.1,
            default_bytes_per_row: 32,
            transfer_cost_multiplier: 100.0,
        }
    }
}
94
/// Cost estimate for a query plan node.
///
/// Captures the multi-dimensional cost of executing a plan node, enabling
/// the optimizer to make informed decisions based on available resources.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PlanCost {
    /// Estimated number of output rows.
    pub rows: u64,

    /// Estimated CPU cost (arbitrary units, relative comparisons only).
    ///
    /// This represents processing overhead that cannot be parallelized on
    /// the GPU, such as coordination, scheduling, and result materialization.
    pub cpu_cost: f64,

    /// Estimated GPU memory usage in bytes.
    ///
    /// Includes both input buffers and intermediate storage required for
    /// the operation.
    pub gpu_mem: u64,

    /// Number of GPU-to-host or host-to-GPU data transfers.
    ///
    /// Transfers are typically the most expensive operations in GPU computing
    /// and should be minimized.
    pub transfers: u32,
}

impl PlanCost {
    /// Creates a cost estimate with the given row count and every other
    /// component set to its zero default.
    pub fn with_rows(rows: u64) -> Self {
        Self {
            rows,
            ..Default::default()
        }
    }

    /// Computes a scalar cost value for comparison purposes.
    ///
    /// The formula weights the cost components as follows:
    /// - CPU cost is taken directly
    /// - GPU memory is scaled by 0.001 (1 GB = 1M cost units)
    /// - Each transfer contributes `transfer_weight`
    ///
    /// The row count does not enter the scalar.
    ///
    /// # Arguments
    ///
    /// * `transfer_weight` - Weight multiplier for transfer costs
    pub fn total_cost(&self, transfer_weight: f64) -> f64 {
        self.cpu_cost + (self.gpu_mem as f64 * 0.001) + (self.transfers as f64 * transfer_weight)
    }

    /// Combines two costs representing sequential operations.
    ///
    /// - `rows` comes from the second (later) operation.
    /// - `cpu_cost` is the sum of both operations.
    /// - `gpu_mem` is the larger of the two (peak usage, not the sum).
    /// - `transfers` is the sum of both, saturating at `u32::MAX` so that
    ///   combining very large estimates can neither panic (debug builds)
    ///   nor wrap around to a deceptively small count (release builds).
    pub fn then(self, other: PlanCost) -> PlanCost {
        PlanCost {
            rows: other.rows,
            cpu_cost: self.cpu_cost + other.cpu_cost,
            gpu_mem: self.gpu_mem.max(other.gpu_mem), // Peak memory usage
            transfers: self.transfers.saturating_add(other.transfers),
        }
    }
}
158
/// Query optimizer using statistics for cost-based decisions.
///
/// The optimizer transforms query plans to improve execution efficiency
/// by applying rewrites like predicate pushdown and using statistics to
/// estimate costs for different plan alternatives.
///
/// Construct one with [`Optimizer::new`] or [`Optimizer::with_config`]; both
/// start with an empty schema map, which [`Optimizer::set_schemas`] replaces.
pub struct Optimizer {
    /// Shared statistics manager; consulted for per-relation statistics,
    /// e.g. as a column-count fallback when no schema is registered.
    stats: Arc<StatsManager>,
    /// Active configuration; `enable_pushdown` gates the rewrite in `optimize`.
    config: OptimizerConfig,
    /// Schemas for relations, keyed by `RelId`.
    schemas: HashMap<RelId, Schema>,
}
170
171impl Optimizer {
172    /// Creates a new optimizer with default configuration.
173    ///
174    /// # Arguments
175    ///
176    /// * `stats` - Shared statistics manager for cardinality and selectivity estimates
177    pub fn new(stats: Arc<StatsManager>) -> Self {
178        Self {
179            stats,
180            config: OptimizerConfig::default(),
181            schemas: HashMap::new(),
182        }
183    }
184
185    /// Creates a new optimizer with custom configuration.
186    ///
187    /// # Arguments
188    ///
189    /// * `stats` - Shared statistics manager
190    /// * `config` - Custom optimizer configuration
191    pub fn with_config(stats: Arc<StatsManager>, config: OptimizerConfig) -> Self {
192        Self {
193            stats,
194            config,
195            schemas: HashMap::new(),
196        }
197    }
198
199    /// Sets the schemas for relations.
200    ///
201    /// This information is used by the optimizer to accurately determine
202    /// column widths during predicate pushdown.
203    pub fn set_schemas(&mut self, schemas: HashMap<RelId, Schema>) {
204        self.schemas = schemas;
205    }
206
207    /// Returns a reference to the current configuration.
208    pub fn config(&self) -> &OptimizerConfig {
209        &self.config
210    }
211
212    /// Returns a reference to the statistics manager.
213    pub fn stats(&self) -> &Arc<StatsManager> {
214        &self.stats
215    }
216
217    /// Optimizes an execution plan by applying transformation rules.
218    ///
219    /// Currently applies:
220    /// - Predicate pushdown (if enabled)
221    ///
222    /// Future optimizations may include:
223    /// - Join reordering based on cardinality estimates
224    /// - Projection pushdown
225    /// - Common subexpression elimination
226    ///
227    /// # Arguments
228    ///
229    /// * `node` - The plan to optimize
230    ///
231    /// # Returns
232    ///
233    /// An optimized plan that is semantically equivalent to the input
234    pub fn optimize(&self, node: RirNode) -> RirNode {
235        if self.config.enable_pushdown {
236            self.predicate_pushdown(node)
237        } else {
238            node
239        }
240    }
241
242    /// Pushes filter predicates closer to scan nodes.
243    ///
244    /// This transformation reduces intermediate result sizes by applying
245    /// filters as early as possible in the query pipeline. The rules are:
246    ///
247    /// - Filters can be pushed through projections (with column remapping)
248    /// - Filters can be pushed into one or both sides of a join if the
249    ///   predicate references only columns from that side
250    /// - Filters on join keys can inform join selectivity estimates
251    ///
252    /// # Arguments
253    ///
254    /// * `node` - The plan node to transform
255    ///
256    /// # Returns
257    ///
258    /// The transformed plan with predicates pushed down where beneficial
259    fn predicate_pushdown(&self, node: RirNode) -> RirNode {
260        match node {
261            // Base case: scan nodes cannot be transformed further
262            RirNode::Unit => RirNode::Unit,
263            RirNode::Scan { rel } => RirNode::Scan { rel },
264
265            // Filter on top of another node: try to push down
266            RirNode::Filter { input, predicate } => {
267                // First, recursively optimize the input
268                let optimized_input = self.predicate_pushdown(*input);
269
270                match optimized_input {
271                    // Filter on Filter: merge predicates
272                    RirNode::Filter {
273                        input: inner_input,
274                        predicate: inner_pred,
275                    } => {
276                        let merged = Expr::And(vec![inner_pred, predicate]);
277                        RirNode::Filter {
278                            input: inner_input,
279                            predicate: merged,
280                        }
281                    }
282
283                    // Filter on Project: push through if possible
284                    RirNode::Project {
285                        input: proj_input,
286                        columns,
287                    } => {
288                        // Check if predicate only references pass-through columns
289                        if let Some(remapped) =
290                            self.remap_predicate_through_project(&predicate, &columns)
291                        {
292                            // Push the remapped predicate below the projection
293                            RirNode::Project {
294                                input: Box::new(RirNode::Filter {
295                                    input: proj_input,
296                                    predicate: remapped,
297                                }),
298                                columns,
299                            }
300                        } else {
301                            // Cannot push: keep filter above
302                            RirNode::Filter {
303                                input: Box::new(RirNode::Project {
304                                    input: proj_input,
305                                    columns,
306                                }),
307                                predicate,
308                            }
309                        }
310                    }
311
312                    // Filter on Join: try to push to appropriate side
313                    RirNode::Join {
314                        left,
315                        right,
316                        left_keys,
317                        right_keys,
318                        join_type,
319                    } => {
320                        let left_width = self.estimate_width(&left);
321                        let (left_preds, right_preds, remaining) =
322                            self.split_predicate_for_join(&predicate, left_width);
323
324                        // Apply pushed predicates to each side
325                        let new_left = if !left_preds.is_empty() {
326                            Box::new(RirNode::Filter {
327                                input: left,
328                                predicate: Self::conjoin(left_preds),
329                            })
330                        } else {
331                            left
332                        };
333
334                        let new_right = if !right_preds.is_empty() {
335                            Box::new(RirNode::Filter {
336                                input: right,
337                                predicate: Self::conjoin(right_preds),
338                            })
339                        } else {
340                            right
341                        };
342
343                        let join_node = RirNode::Join {
344                            left: new_left,
345                            right: new_right,
346                            left_keys,
347                            right_keys,
348                            join_type,
349                        };
350
351                        // Apply remaining predicates that couldn't be pushed
352                        if !remaining.is_empty() {
353                            RirNode::Filter {
354                                input: Box::new(join_node),
355                                predicate: Self::conjoin(remaining),
356                            }
357                        } else {
358                            join_node
359                        }
360                    }
361
362                    // Default: cannot push further
363                    other => RirNode::Filter {
364                        input: Box::new(other),
365                        predicate,
366                    },
367                }
368            }
369
370            // Project: recursively optimize input
371            RirNode::Project { input, columns } => RirNode::Project {
372                input: Box::new(self.predicate_pushdown(*input)),
373                columns,
374            },
375
376            // Join: recursively optimize both sides
377            RirNode::Join {
378                left,
379                right,
380                left_keys,
381                right_keys,
382                join_type,
383            } => RirNode::Join {
384                left: Box::new(self.predicate_pushdown(*left)),
385                right: Box::new(self.predicate_pushdown(*right)),
386                left_keys,
387                right_keys,
388                join_type,
389            },
390
391            // GroupBy: recursively optimize input
392            RirNode::GroupBy {
393                input,
394                key_cols,
395                aggs,
396            } => RirNode::GroupBy {
397                input: Box::new(self.predicate_pushdown(*input)),
398                key_cols,
399                aggs,
400            },
401
402            // Union: recursively optimize all inputs
403            RirNode::Union { inputs } => RirNode::Union {
404                inputs: inputs
405                    .into_iter()
406                    .map(|i| self.predicate_pushdown(i))
407                    .collect(),
408            },
409
410            // Distinct: recursively optimize input
411            RirNode::Distinct { input, key_cols } => RirNode::Distinct {
412                input: Box::new(self.predicate_pushdown(*input)),
413                key_cols,
414            },
415
416            // Diff: recursively optimize both sides
417            RirNode::Diff { left, right } => RirNode::Diff {
418                left: Box::new(self.predicate_pushdown(*left)),
419                right: Box::new(self.predicate_pushdown(*right)),
420            },
421
422            // Fixpoint: recursively optimize base and recursive parts
423            RirNode::Fixpoint {
424                scc_id,
425                base,
426                recursive,
427                delta_rel,
428                full_rel,
429            } => RirNode::Fixpoint {
430                scc_id,
431                base: Box::new(self.predicate_pushdown(*base)),
432                recursive: Box::new(self.predicate_pushdown(*recursive)),
433                delta_rel,
434                full_rel,
435            },
436
437            RirNode::TensorMaskedJoin { .. } => node, // Leaf-like: no pushdown
438
439            // Promoted physical-shape nodes are produced after the
440            // optimizer runs. Required for compile safety and as a
441            // no-op fallback if the call order ever changes.
442            RirNode::MultiWayJoin { .. } | RirNode::ChainJoin { .. } => node,
443        }
444    }
445
446    /// Attempts to remap a predicate through a projection.
447    ///
448    /// Returns `Some(remapped_predicate)` if all column references in the
449    /// predicate can be traced back through pass-through columns.
450    /// Returns `None` if the predicate references computed columns.
451    fn remap_predicate_through_project(
452        &self,
453        predicate: &Expr,
454        columns: &[xlog_ir::ProjectExpr],
455    ) -> Option<Expr> {
456        // Build a mapping from output column index to input column index
457        // Only for pass-through columns
458        let mut output_to_input: std::collections::HashMap<usize, usize> =
459            std::collections::HashMap::new();
460
461        for (out_idx, proj_expr) in columns.iter().enumerate() {
462            if let xlog_ir::ProjectExpr::Column(in_idx) = proj_expr {
463                output_to_input.insert(out_idx, *in_idx);
464            }
465        }
466
467        self.remap_expr(predicate, &output_to_input)
468    }
469
470    /// Recursively remaps column references in an expression.
471    fn remap_expr(
472        &self,
473        expr: &Expr,
474        mapping: &std::collections::HashMap<usize, usize>,
475    ) -> Option<Expr> {
476        match expr {
477            Expr::Column(idx) => mapping.get(idx).map(|&new_idx| Expr::Column(new_idx)),
478
479            Expr::Const(val) => Some(Expr::Const(val.clone())),
480
481            Expr::Compare { left, op, right } => {
482                let new_left = self.remap_expr(left, mapping)?;
483                let new_right = self.remap_expr(right, mapping)?;
484                Some(Expr::Compare {
485                    left: Box::new(new_left),
486                    op: *op,
487                    right: Box::new(new_right),
488                })
489            }
490
491            Expr::And(exprs) => {
492                let remapped: Option<Vec<_>> =
493                    exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
494                remapped.map(Expr::And)
495            }
496
497            Expr::Or(exprs) => {
498                let remapped: Option<Vec<_>> =
499                    exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
500                remapped.map(Expr::Or)
501            }
502
503            Expr::Not(inner) => {
504                let remapped = self.remap_expr(inner, mapping)?;
505                Some(Expr::Not(Box::new(remapped)))
506            }
507
508            // Arithmetic operations
509            Expr::Add(l, r) => {
510                let new_l = self.remap_expr(l, mapping)?;
511                let new_r = self.remap_expr(r, mapping)?;
512                Some(Expr::Add(Box::new(new_l), Box::new(new_r)))
513            }
514            Expr::Sub(l, r) => {
515                let new_l = self.remap_expr(l, mapping)?;
516                let new_r = self.remap_expr(r, mapping)?;
517                Some(Expr::Sub(Box::new(new_l), Box::new(new_r)))
518            }
519            Expr::Mul(l, r) => {
520                let new_l = self.remap_expr(l, mapping)?;
521                let new_r = self.remap_expr(r, mapping)?;
522                Some(Expr::Mul(Box::new(new_l), Box::new(new_r)))
523            }
524            Expr::Div(l, r) => {
525                let new_l = self.remap_expr(l, mapping)?;
526                let new_r = self.remap_expr(r, mapping)?;
527                Some(Expr::Div(Box::new(new_l), Box::new(new_r)))
528            }
529            Expr::Mod(l, r) => {
530                let new_l = self.remap_expr(l, mapping)?;
531                let new_r = self.remap_expr(r, mapping)?;
532                Some(Expr::Mod(Box::new(new_l), Box::new(new_r)))
533            }
534
535            // Built-in functions
536            Expr::Abs(inner) => {
537                let remapped = self.remap_expr(inner, mapping)?;
538                Some(Expr::Abs(Box::new(remapped)))
539            }
540            Expr::Min(l, r) => {
541                let new_l = self.remap_expr(l, mapping)?;
542                let new_r = self.remap_expr(r, mapping)?;
543                Some(Expr::Min(Box::new(new_l), Box::new(new_r)))
544            }
545            Expr::Max(l, r) => {
546                let new_l = self.remap_expr(l, mapping)?;
547                let new_r = self.remap_expr(r, mapping)?;
548                Some(Expr::Max(Box::new(new_l), Box::new(new_r)))
549            }
550            Expr::Pow(l, r) => {
551                let new_l = self.remap_expr(l, mapping)?;
552                let new_r = self.remap_expr(r, mapping)?;
553                Some(Expr::Pow(Box::new(new_l), Box::new(new_r)))
554            }
555            Expr::Cast(inner, scalar_type) => {
556                let remapped = self.remap_expr(inner, mapping)?;
557                Some(Expr::Cast(Box::new(remapped), *scalar_type))
558            }
559            Expr::Conditional {
560                condition,
561                then_expr,
562                else_expr,
563            } => {
564                let new_condition = self.remap_expr(condition, mapping)?;
565                let new_then = self.remap_expr(then_expr, mapping)?;
566                let new_else = self.remap_expr(else_expr, mapping)?;
567                Some(Expr::Conditional {
568                    condition: Box::new(new_condition),
569                    then_expr: Box::new(new_then),
570                    else_expr: Box::new(new_else),
571                })
572            }
573        }
574    }
575
576    /// Estimates the output width (number of columns) of a plan node.
577    fn estimate_width(&self, node: &RirNode) -> usize {
578        match node {
579            RirNode::Unit => 0,
580            RirNode::Scan { rel } => {
581                // Use schema if available, otherwise stats, otherwise default
582                if let Some(schema) = self.schemas.get(rel) {
583                    schema.arity()
584                } else if let Some(stats) = self.stats.get_relation_stats(*rel) {
585                    stats.column_stats.len().max(1)
586                } else {
587                    4 // Default assumption
588                }
589            }
590            RirNode::Filter { input, .. } => self.estimate_width(input),
591            RirNode::Project { columns, .. } => columns.len(),
592            RirNode::Join { left, right, .. } => {
593                self.estimate_width(left) + self.estimate_width(right)
594            }
595            RirNode::ChainJoin { output_columns, .. } => output_columns.len(),
596            RirNode::GroupBy { key_cols, aggs, .. } => key_cols.len() + aggs.len(),
597            RirNode::Union { inputs } => {
598                inputs.first().map(|i| self.estimate_width(i)).unwrap_or(0)
599            }
600            RirNode::Distinct { input, .. } => self.estimate_width(input),
601            RirNode::Diff { left, .. } => self.estimate_width(left),
602            RirNode::Fixpoint { base, .. } => self.estimate_width(base),
603            // RD-27: Optimizer schemas are HashMap<RelId, Schema>.
604            // Use head_rel_id (not head_rel_name) for lookup.
605            RirNode::TensorMaskedJoin { head_rel_id, .. } => self
606                .schemas
607                .get(head_rel_id)
608                .map(|s| s.arity())
609                .unwrap_or(2),
610            // v0.6.5: `MultiWayJoin` post-promoter only — width equals
611            // the head projection arity, mirroring the Project arm.
612            RirNode::MultiWayJoin { output_columns, .. } => output_columns.len(),
613        }
614    }
615
616    /// Splits a predicate into parts pushable to left, right, or neither side of a join.
617    ///
618    /// Returns (left_predicates, right_predicates, remaining_predicates).
619    fn split_predicate_for_join(
620        &self,
621        predicate: &Expr,
622        left_width: usize,
623    ) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
624        let mut left_preds = Vec::new();
625        let mut right_preds = Vec::new();
626        let mut remaining = Vec::new();
627
628        // Flatten AND expressions
629        let conjuncts = Self::flatten_and(predicate);
630
631        for conj in conjuncts {
632            let cols = Self::collect_columns(&conj);
633            let max_col = cols.iter().copied().max().unwrap_or(0);
634            let min_col = cols.iter().copied().min().unwrap_or(0);
635
636            if cols.is_empty() {
637                // No columns referenced, can push to either side
638                left_preds.push(conj);
639            } else if max_col < left_width {
640                // All columns from left side
641                left_preds.push(conj);
642            } else if min_col >= left_width {
643                // All columns from right side - need to remap
644                let remapped = Self::remap_columns(&conj, |c| c - left_width);
645                right_preds.push(remapped);
646            } else {
647                // References both sides, cannot push
648                remaining.push(conj);
649            }
650        }
651
652        (left_preds, right_preds, remaining)
653    }
654
655    /// Flattens nested AND expressions into a list of conjuncts.
656    fn flatten_and(expr: &Expr) -> Vec<Expr> {
657        match expr {
658            Expr::And(exprs) => exprs.iter().flat_map(Self::flatten_and).collect(),
659            other => vec![other.clone()],
660        }
661    }
662
663    /// Collects all column indices referenced in an expression.
664    fn collect_columns(expr: &Expr) -> Vec<usize> {
665        match expr {
666            Expr::Column(idx) => vec![*idx],
667            Expr::Const(_) => vec![],
668            Expr::Compare { left, right, .. } => {
669                let mut cols = Self::collect_columns(left);
670                cols.extend(Self::collect_columns(right));
671                cols
672            }
673            Expr::And(exprs) | Expr::Or(exprs) => {
674                exprs.iter().flat_map(Self::collect_columns).collect()
675            }
676            Expr::Not(inner) | Expr::Abs(inner) | Expr::Cast(inner, _) => {
677                Self::collect_columns(inner)
678            }
679            Expr::Add(l, r)
680            | Expr::Sub(l, r)
681            | Expr::Mul(l, r)
682            | Expr::Div(l, r)
683            | Expr::Mod(l, r)
684            | Expr::Min(l, r)
685            | Expr::Max(l, r)
686            | Expr::Pow(l, r) => {
687                let mut cols = Self::collect_columns(l);
688                cols.extend(Self::collect_columns(r));
689                cols
690            }
691            Expr::Conditional {
692                condition,
693                then_expr,
694                else_expr,
695            } => {
696                let mut cols = Self::collect_columns(condition);
697                cols.extend(Self::collect_columns(then_expr));
698                cols.extend(Self::collect_columns(else_expr));
699                cols
700            }
701        }
702    }
703
704    /// Remaps column references in an expression using a transformation function.
705    fn remap_columns<F: Fn(usize) -> usize + Copy>(expr: &Expr, f: F) -> Expr {
706        match expr {
707            Expr::Column(idx) => Expr::Column(f(*idx)),
708            Expr::Const(v) => Expr::Const(v.clone()),
709            Expr::Compare { left, op, right } => Expr::Compare {
710                left: Box::new(Self::remap_columns(left, f)),
711                op: *op,
712                right: Box::new(Self::remap_columns(right, f)),
713            },
714            Expr::And(exprs) => {
715                Expr::And(exprs.iter().map(|e| Self::remap_columns(e, f)).collect())
716            }
717            Expr::Or(exprs) => Expr::Or(exprs.iter().map(|e| Self::remap_columns(e, f)).collect()),
718            Expr::Not(inner) => Expr::Not(Box::new(Self::remap_columns(inner, f))),
719            Expr::Add(l, r) => Expr::Add(
720                Box::new(Self::remap_columns(l, f)),
721                Box::new(Self::remap_columns(r, f)),
722            ),
723            Expr::Sub(l, r) => Expr::Sub(
724                Box::new(Self::remap_columns(l, f)),
725                Box::new(Self::remap_columns(r, f)),
726            ),
727            Expr::Mul(l, r) => Expr::Mul(
728                Box::new(Self::remap_columns(l, f)),
729                Box::new(Self::remap_columns(r, f)),
730            ),
731            Expr::Div(l, r) => Expr::Div(
732                Box::new(Self::remap_columns(l, f)),
733                Box::new(Self::remap_columns(r, f)),
734            ),
735            Expr::Mod(l, r) => Expr::Mod(
736                Box::new(Self::remap_columns(l, f)),
737                Box::new(Self::remap_columns(r, f)),
738            ),
739            Expr::Abs(inner) => Expr::Abs(Box::new(Self::remap_columns(inner, f))),
740            Expr::Min(l, r) => Expr::Min(
741                Box::new(Self::remap_columns(l, f)),
742                Box::new(Self::remap_columns(r, f)),
743            ),
744            Expr::Max(l, r) => Expr::Max(
745                Box::new(Self::remap_columns(l, f)),
746                Box::new(Self::remap_columns(r, f)),
747            ),
748            Expr::Pow(l, r) => Expr::Pow(
749                Box::new(Self::remap_columns(l, f)),
750                Box::new(Self::remap_columns(r, f)),
751            ),
752            Expr::Cast(inner, t) => Expr::Cast(Box::new(Self::remap_columns(inner, f)), *t),
753            Expr::Conditional {
754                condition,
755                then_expr,
756                else_expr,
757            } => Expr::Conditional {
758                condition: Box::new(Self::remap_columns(condition, f)),
759                then_expr: Box::new(Self::remap_columns(then_expr, f)),
760                else_expr: Box::new(Self::remap_columns(else_expr, f)),
761            },
762        }
763    }
764
765    /// Combines a list of predicates into a single AND expression.
766    fn conjoin(predicates: Vec<Expr>) -> Expr {
767        debug_assert!(!predicates.is_empty());
768        if predicates.len() == 1 {
769            predicates.into_iter().next().unwrap()
770        } else {
771            Expr::And(predicates)
772        }
773    }
774
775    /// Estimates the cost of executing a plan node.
776    ///
777    /// Recursively computes cost estimates for the entire plan tree,
778    /// using statistics when available and falling back to heuristics.
779    ///
780    /// # Arguments
781    ///
782    /// * `node` - The plan node to estimate
783    ///
784    /// # Returns
785    ///
786    /// A [`PlanCost`] with estimated rows, CPU cost, GPU memory, and transfers
787    pub fn estimate_cost(&self, node: &RirNode) -> PlanCost {
788        match node {
789            RirNode::Unit => PlanCost {
790                rows: 1,
791                cpu_cost: 0.0,
792                gpu_mem: 0,
793                transfers: 0,
794            },
795            RirNode::Scan { rel } => self.estimate_scan_cost(*rel),
796
797            RirNode::Filter { input, predicate } => {
798                let input_cost = self.estimate_cost(input);
799                self.estimate_filter_cost(input_cost, predicate, input)
800            }
801
802            RirNode::Project { input, columns } => {
803                let input_cost = self.estimate_cost(input);
804                self.estimate_project_cost(input_cost, columns)
805            }
806
807            RirNode::Join {
808                left,
809                right,
810                left_keys,
811                right_keys,
812                join_type,
813            } => {
814                let left_cost = self.estimate_cost(left);
815                let right_cost = self.estimate_cost(right);
816                self.estimate_join_cost(
817                    left_cost, right_cost, left, right, left_keys, right_keys, *join_type,
818                )
819            }
820
821            RirNode::ChainJoin {
822                left,
823                right,
824                left_key,
825                right_key,
826                output_columns,
827                ..
828            } => {
829                let left_cost = self.estimate_cost(left);
830                let right_cost = self.estimate_cost(right);
831                let join_cost = self.estimate_join_cost(
832                    left_cost,
833                    right_cost,
834                    left,
835                    right,
836                    &[*left_key],
837                    &[*right_key],
838                    JoinType::Inner,
839                );
840                self.estimate_project_cost(join_cost, output_columns)
841            }
842
843            RirNode::GroupBy {
844                input,
845                key_cols,
846                aggs,
847            } => {
848                let input_cost = self.estimate_cost(input);
849                self.estimate_groupby_cost(input_cost, key_cols, aggs)
850            }
851
852            RirNode::Union { inputs } => {
853                let costs: Vec<_> = inputs.iter().map(|i| self.estimate_cost(i)).collect();
854                self.estimate_union_cost(costs)
855            }
856
857            RirNode::Distinct { input, key_cols } => {
858                let input_cost = self.estimate_cost(input);
859                self.estimate_distinct_cost(input_cost, key_cols)
860            }
861
862            RirNode::Diff { left, right } => {
863                let left_cost = self.estimate_cost(left);
864                let right_cost = self.estimate_cost(right);
865                self.estimate_diff_cost(left_cost, right_cost)
866            }
867
868            RirNode::Fixpoint {
869                base, recursive, ..
870            } => {
871                let base_cost = self.estimate_cost(base);
872                let recursive_cost = self.estimate_cost(recursive);
873                self.estimate_fixpoint_cost(base_cost, recursive_cost)
874            }
875
876            RirNode::TensorMaskedJoin {
877                max_active_rules, ..
878            } => PlanCost {
879                rows: *max_active_rules as u64,
880                cpu_cost: *max_active_rules as f64 * 100.0,
881                gpu_mem: *max_active_rules as u64 * 1024,
882                transfers: 1,
883            },
884            // v0.6.5: `MultiWayJoin` cost is the sum of input scan costs.
885            // Heuristic only — the post-promoter dispatch decides whether
886            // to run the WCOJ kernel or fall back; cost-model integration
887            // for the multiway operator itself is later-slice work.
888            RirNode::MultiWayJoin { inputs, .. } => {
889                let mut total = PlanCost::default();
890                for inp in inputs {
891                    let c = self.estimate_cost(inp);
892                    total.rows = total.rows.saturating_add(c.rows);
893                    total.cpu_cost += c.cpu_cost;
894                    total.gpu_mem = total.gpu_mem.saturating_add(c.gpu_mem);
895                    total.transfers = total.transfers.saturating_add(c.transfers);
896                }
897                total
898            }
899        }
900    }
901
902    /// Estimates cost for a base relation scan.
903    fn estimate_scan_cost(&self, rel: RelId) -> PlanCost {
904        if let Some(stats) = self.stats.get_relation_stats(rel) {
905            PlanCost {
906                rows: stats.cardinality,
907                cpu_cost: stats.cardinality as f64 * 0.01, // Minimal per-row CPU cost
908                gpu_mem: stats
909                    .byte_size
910                    .max(stats.cardinality * self.config.default_bytes_per_row),
911                transfers: 0, // Data already on GPU
912            }
913        } else {
914            // Default estimates for unknown relations
915            let default_rows = 1000;
916            PlanCost {
917                rows: default_rows,
918                cpu_cost: default_rows as f64 * 0.01,
919                gpu_mem: default_rows * self.config.default_bytes_per_row,
920                transfers: 0,
921            }
922        }
923    }
924
925    /// Estimates cost for a filter operation.
926    fn estimate_filter_cost(
927        &self,
928        input_cost: PlanCost,
929        predicate: &Expr,
930        input: &RirNode,
931    ) -> PlanCost {
932        let selectivity = self.estimate_predicate_selectivity(predicate, input);
933        let output_rows = ((input_cost.rows as f64 * selectivity) as u64).max(1);
934
935        PlanCost {
936            rows: output_rows,
937            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.02, // Predicate eval cost
938            gpu_mem: input_cost.gpu_mem, // Filter reuses input memory
939            transfers: input_cost.transfers,
940        }
941    }
942
943    /// Estimates cost for a projection operation.
944    fn estimate_project_cost(
945        &self,
946        input_cost: PlanCost,
947        columns: &[xlog_ir::ProjectExpr],
948    ) -> PlanCost {
949        // Count computed vs pass-through columns
950        let computed_count = columns
951            .iter()
952            .filter(|c| matches!(c, xlog_ir::ProjectExpr::Computed(_, _)))
953            .count();
954
955        // Computed columns add CPU cost
956        let compute_cost = computed_count as f64 * input_cost.rows as f64 * 0.05;
957
958        // Output size may be smaller if fewer columns
959        let output_width_ratio = columns.len() as f64 / (columns.len() + 2) as f64; // Rough estimate
960
961        PlanCost {
962            rows: input_cost.rows,
963            cpu_cost: input_cost.cpu_cost + compute_cost,
964            gpu_mem: (input_cost.gpu_mem as f64 * output_width_ratio) as u64,
965            transfers: input_cost.transfers,
966        }
967    }
968
969    /// Estimates cost for a join operation.
970    #[allow(clippy::too_many_arguments)]
971    fn estimate_join_cost(
972        &self,
973        left_cost: PlanCost,
974        right_cost: PlanCost,
975        left: &RirNode,
976        right: &RirNode,
977        left_keys: &[usize],
978        right_keys: &[usize],
979        join_type: JoinType,
980    ) -> PlanCost {
981        // Semi and Anti joins always produce at most left_cost.rows
982        // Handle these specially before checking stats
983        let output_rows = match join_type {
984            JoinType::Semi => {
985                // At most left side rows, estimate 50% match
986                ((left_cost.rows as f64 * 0.5) as u64).max(1)
987            }
988            JoinType::Anti => {
989                // At most left side rows, estimate 50% don't match
990                ((left_cost.rows as f64 * 0.5) as u64).max(1)
991            }
992            JoinType::Inner | JoinType::LeftOuter => {
993                // Get relation IDs for selectivity lookup
994                let left_rels = left.referenced_relations();
995                let right_rels = right.referenced_relations();
996
997                if left_rels.len() == 1 && right_rels.len() == 1 {
998                    // Simple join between two base relations
999                    let estimated = self.stats.estimate_join_cardinality(
1000                        left_rels[0],
1001                        right_rels[0],
1002                        left_keys,
1003                        right_keys,
1004                    );
1005
1006                    match join_type {
1007                        JoinType::LeftOuter => estimated.max(left_cost.rows),
1008                        _ => estimated,
1009                    }
1010                } else {
1011                    // Multi-way or complex join: use heuristic
1012                    match join_type {
1013                        JoinType::Inner => {
1014                            // Assume 10% selectivity for inner joins
1015                            ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64).max(1)
1016                        }
1017                        JoinType::LeftOuter => {
1018                            // At least left side rows
1019                            left_cost.rows.max(
1020                                ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64)
1021                                    .max(1),
1022                            )
1023                        }
1024                        _ => unreachable!(),
1025                    }
1026                }
1027            }
1028        };
1029
1030        // Join CPU cost: hash build + probe
1031        let build_cost = right_cost.rows as f64 * 1.0; // Build hash table
1032        let probe_cost = left_cost.rows as f64 * 0.5; // Probe operations
1033        let cpu_cost = left_cost.cpu_cost + right_cost.cpu_cost + build_cost + probe_cost;
1034
1035        // GPU memory: both inputs plus hash table overhead
1036        let hash_table_overhead = right_cost.gpu_mem / 2; // Approximate hash table size
1037        let gpu_mem = left_cost.gpu_mem + right_cost.gpu_mem + hash_table_overhead;
1038
1039        PlanCost {
1040            rows: output_rows,
1041            cpu_cost,
1042            gpu_mem,
1043            transfers: left_cost.transfers + right_cost.transfers,
1044        }
1045    }
1046
1047    /// Estimates cost for a group-by with aggregation.
1048    fn estimate_groupby_cost(
1049        &self,
1050        input_cost: PlanCost,
1051        key_cols: &[usize],
1052        _aggs: &[(usize, xlog_core::AggOp)],
1053    ) -> PlanCost {
1054        // Estimate distinct groups based on key columns
1055        // Heuristic: sqrt(input_rows) for unknown cardinality
1056        let estimated_groups = if key_cols.is_empty() {
1057            1 // No grouping = single result
1058        } else {
1059            // Rough estimate: assume good reduction
1060            ((input_cost.rows as f64).sqrt() as u64).max(1)
1061        };
1062
1063        PlanCost {
1064            rows: estimated_groups,
1065            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.5, // Aggregation cost
1066            gpu_mem: input_cost.gpu_mem + estimated_groups * self.config.default_bytes_per_row,
1067            transfers: input_cost.transfers,
1068        }
1069    }
1070
1071    /// Estimates cost for a union operation.
1072    fn estimate_union_cost(&self, input_costs: Vec<PlanCost>) -> PlanCost {
1073        let total_rows: u64 = input_costs.iter().map(|c| c.rows).sum();
1074        let total_cpu: f64 = input_costs.iter().map(|c| c.cpu_cost).sum();
1075        let max_gpu: u64 = input_costs.iter().map(|c| c.gpu_mem).max().unwrap_or(0);
1076        let total_transfers: u32 = input_costs.iter().map(|c| c.transfers).sum();
1077
1078        PlanCost {
1079            rows: total_rows,
1080            cpu_cost: total_cpu + total_rows as f64 * 0.01, // Concatenation cost
1081            gpu_mem: max_gpu,                               // Can process sequentially
1082            transfers: total_transfers,
1083        }
1084    }
1085
1086    /// Estimates cost for a distinct operation.
1087    fn estimate_distinct_cost(&self, input_cost: PlanCost, _key_cols: &[usize]) -> PlanCost {
1088        // Heuristic: distinct reduces rows by some factor
1089        let estimated_distinct = (input_cost.rows as f64 * 0.7) as u64;
1090
1091        PlanCost {
1092            rows: estimated_distinct.max(1),
1093            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.3, // Hash-based dedup
1094            gpu_mem: input_cost.gpu_mem + input_cost.rows * 8,            // Hash set overhead
1095            transfers: input_cost.transfers,
1096        }
1097    }
1098
1099    /// Estimates cost for a set difference operation.
1100    fn estimate_diff_cost(&self, left_cost: PlanCost, right_cost: PlanCost) -> PlanCost {
1101        // Diff removes matching rows from left
1102        let estimated_remaining = (left_cost.rows as f64 * 0.5) as u64;
1103
1104        PlanCost {
1105            rows: estimated_remaining.max(1),
1106            cpu_cost: left_cost.cpu_cost + right_cost.cpu_cost + right_cost.rows as f64 * 0.5,
1107            gpu_mem: left_cost.gpu_mem + right_cost.gpu_mem,
1108            transfers: left_cost.transfers + right_cost.transfers,
1109        }
1110    }
1111
1112    /// Estimates cost for a fixpoint (recursive) operation.
1113    fn estimate_fixpoint_cost(&self, base_cost: PlanCost, recursive_cost: PlanCost) -> PlanCost {
1114        // Fixpoint cost depends on number of iterations
1115        // Heuristic: assume log2(base_rows) iterations
1116        let estimated_iterations = ((base_cost.rows as f64).log2().ceil() as u64).max(1);
1117
1118        PlanCost {
1119            rows: base_cost.rows * estimated_iterations, // Output accumulates
1120            cpu_cost: base_cost.cpu_cost + recursive_cost.cpu_cost * estimated_iterations as f64,
1121            gpu_mem: (base_cost.gpu_mem + recursive_cost.gpu_mem) * 2, // Need delta and full
1122            transfers: base_cost.transfers + recursive_cost.transfers * estimated_iterations as u32,
1123        }
1124    }
1125
1126    /// Estimates selectivity of a predicate expression.
1127    fn estimate_predicate_selectivity(&self, predicate: &Expr, input: &RirNode) -> f64 {
1128        match predicate {
1129            Expr::Compare { left, op, right } => {
1130                self.estimate_compare_selectivity(left, *op, right, input)
1131            }
1132            Expr::And(exprs) => {
1133                // Multiply selectivities (independence assumption)
1134                exprs
1135                    .iter()
1136                    .map(|e| self.estimate_predicate_selectivity(e, input))
1137                    .product()
1138            }
1139            Expr::Or(exprs) => {
1140                // P(A or B) = P(A) + P(B) - P(A)P(B) for independent events
1141                // Simplified: max of selectivities as lower bound
1142                exprs
1143                    .iter()
1144                    .map(|e| self.estimate_predicate_selectivity(e, input))
1145                    .fold(0.0, f64::max)
1146            }
1147            Expr::Not(inner) => 1.0 - self.estimate_predicate_selectivity(inner, input),
1148            _ => self.config.default_filter_selectivity,
1149        }
1150    }
1151
1152    /// Estimates selectivity for a comparison predicate.
1153    fn estimate_compare_selectivity(
1154        &self,
1155        left: &Expr,
1156        op: CompareOp,
1157        right: &Expr,
1158        input: &RirNode,
1159    ) -> f64 {
1160        // Try to get column statistics if comparing column to constant
1161        if let (Expr::Column(col_idx), Expr::Const(_)) | (Expr::Const(_), Expr::Column(col_idx)) =
1162            (left, right)
1163        {
1164            // Find the base relation for this column
1165            if let Some(rel_id) = self.find_column_relation(input, *col_idx) {
1166                if let Some(stats) = self.stats.get_relation_stats(rel_id) {
1167                    if let Some(col_stats) = stats.get_column(*col_idx) {
1168                        return match op {
1169                            CompareOp::Eq => col_stats.equality_selectivity(stats.cardinality),
1170                            CompareOp::Ne => {
1171                                1.0 - col_stats.equality_selectivity(stats.cardinality)
1172                            }
1173                            CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => {
1174                                // Range predicates: estimate ~33% selectivity
1175                                0.33
1176                            }
1177                        };
1178                    }
1179                }
1180            }
1181        }
1182
1183        // Default selectivities by operator
1184        match op {
1185            CompareOp::Eq => 0.1, // 10% for equality
1186            CompareOp::Ne => 0.9, // 90% for inequality
1187            CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => 0.33, // 33% for ranges
1188        }
1189    }
1190
1191    /// Finds the base relation that provides a given column.
1192    fn find_column_relation(&self, node: &RirNode, col_idx: usize) -> Option<RelId> {
1193        match node {
1194            RirNode::Scan { rel } => Some(*rel),
1195            RirNode::Filter { input, .. } => self.find_column_relation(input, col_idx),
1196            RirNode::Project { input, columns } => {
1197                // Trace column through projection
1198                if col_idx < columns.len() {
1199                    if let xlog_ir::ProjectExpr::Column(src_idx) = &columns[col_idx] {
1200                        return self.find_column_relation(input, *src_idx);
1201                    }
1202                }
1203                None
1204            }
1205            RirNode::Join { left, right, .. } => {
1206                let left_width = self.estimate_width(left);
1207                if col_idx < left_width {
1208                    self.find_column_relation(left, col_idx)
1209                } else {
1210                    self.find_column_relation(right, col_idx - left_width)
1211                }
1212            }
1213            // v0.6.5: per slice 1 guardrail — return None for
1214            // `MultiWayJoin`. The promoter runs after the optimizer,
1215            // so this arm is unreachable in production. A half-mapped
1216            // implementation that walked `inputs` via `slot_vars` would
1217            // be more dangerous than `None` for this slice.
1218            RirNode::MultiWayJoin { .. } => None,
1219            _ => None, // Complex cases: give up
1220        }
1221    }
1222
1223    /// Returns relations that should have indexes built based on access heat.
1224    ///
1225    /// This is useful for adaptive query processing where frequently accessed
1226    /// relations benefit from index structures.
1227    pub fn recommend_indexes(&self) -> Vec<RelId> {
1228        self.stats.hot_relations(self.config.index_heat_threshold)
1229    }
1230
1231    /// Returns true if the query involves more relations than the DP threshold.
1232    ///
1233    /// Used to decide between exhaustive and greedy join ordering algorithms.
1234    pub fn should_use_greedy(&self, node: &RirNode) -> bool {
1235        let rels = node.referenced_relations();
1236        let unique_rels: std::collections::HashSet<_> = rels.iter().collect();
1237        unique_rels.len() > self.config.dp_threshold
1238    }
1239}
1240
// v0.6.5 slice 3 — selectivity-aware optimizer pass.
//
// Slice 3 introduced this pass as a no-op seam. W2.2 (see the module docs
// below) added the real reordering logic, which consults `stats` to pick
// join orderings by selectivity for canonical triangle and 4-cycle bodies.
//
// Walks `plan.rules_by_scc[*].body` and replaces recognized bodies in place.
// Bodies that are not recognized, or that hit the safety floor (missing or
// zero-cardinality stats), are left untouched. In that case every existing
// plan tree is preserved byte-for-byte; tests assert structural equality
// for triangle, 4-cycle, and recursive-SCC plans.
//
// Compile-pipeline ordering: runs between `Optimizer::optimize` and
// `xlog_logic::promote::promote_multiway`.
1252pub mod selectivity_pass {
1253    //! v0.6.5 W2.2 — selectivity-driven join reordering for
1254    //! canonical lowered triangle and 4-cycle bodies.
1255    //!
1256    //! ## Behavior
1257    //!
1258    //! For each rule body that matches the canonical lowered
1259    //! triangle or 4-cycle shape, the pass enumerates the valid
1260    //! candidate inner pairings (3 for triangle, 2 for 4-cycle),
1261    //! computes each candidate's
1262    //! `StatsManager::estimate_join_cardinality` with
1263    //! **pair-derived join keys from the shared-variable
1264    //! mapping**, and rewrites the body so the smallest-cost
1265    //! choice is materialized first. Tie → keep the optimizer's
1266    //! existing order (deterministic no-op).
1267    //!
1268    //! ## Safety floor
1269    //!
1270    //! If any input atom for a recognized body has no
1271    //! `StatsManager` entry OR `cardinality == 0`, the body is
1272    //! left unchanged. Recursive deltas / freshly-uploaded
1273    //! relations / unseeded predicates therefore stay on the
1274    //! optimizer's default order until stats are populated.
1275    //!
1276    //! ## Default-fallback edge case
1277    //!
1278    //! `StatsManager::estimate_join_cardinality` returns `u64`
1279    //! with no provenance — the caller cannot tell whether the
1280    //! estimate came from the cached `JoinSelectivity` table,
1281    //! the column-distinct heuristic, or the 10% default
1282    //! fallback. When all input atoms have populated
1283    //! cardinalities but no column statistics, the per-pair
1284    //! estimates may all collapse to the same fallback ratio,
1285    //! making the chosen pairing uninformative. **This is an
1286    //! accepted trade-off**: row-set parity holds regardless of
1287    //! selectivity quality (the rewrite preserves semantics);
1288    //! the integration certs gate on row-set + WCOJ-dispatch
1289    //! correctness, not on optimal pair choice.
1290    //!
1291    //! ## Promoter coordination
1292    //!
1293    //! The slice 1 / slice 2 promoters were extended in W2.2
1294    //! step 2a to accept the canonical *semantic* shape with
1295    //! any valid key combination — they emit
1296    //! `MultiWayJoin.inputs` and `slot_vars` in canonical
1297    //! semantic order regardless of the body's positional
1298    //! layout. Reordered bodies therefore still promote and
1299    //! still dispatch the WCOJ kernel correctly.
1300    use std::collections::HashMap;
1301    use xlog_core::RelId;
1302    use xlog_ir::ExecutionPlan;
1303    use xlog_stats::StatsManager;
1304
1305    /// W2.2: selectivity-driven join reordering for canonical
1306    /// triangle + 4-cycle bodies. See module-level doc.
1307    ///
1308    /// `rel_ids` is the predicate-name → RelId map used to
1309    /// resolve body Scans against `StatsManager` lookups.
1310    /// Production callers pass `Compiler::lowerer().rel_ids()`.
1311    /// Test callers can pass an empty map; with no
1312    /// `StatsManager` entries either, the safety floor leaves
1313    /// every body unchanged (legacy no-op behavior preserved).
1314    pub fn run(plan: &mut ExecutionPlan, stats: &StatsManager, rel_ids: &HashMap<String, RelId>) {
1315        // `rel_ids` is reserved for future shape-extension
1316        // work; the current rewriters operate on RelIds
1317        // directly from the body's Scans, so the map isn't
1318        // consulted here. Production callers still pass it
1319        // so the API surface is forward-compatible.
1320        let _ = rel_ids;
1321        for rules in plan.rules_by_scc.iter_mut() {
1322            for rule in rules.iter_mut() {
1323                if let Some(rewritten) = super::reorder::try_reorder_triangle(&rule.body, stats) {
1324                    rule.body = rewritten;
1325                    continue;
1326                }
1327                if let Some(rewritten) = super::reorder::try_reorder_4cycle(&rule.body, stats) {
1328                    rule.body = rewritten;
1329                }
1330            }
1331        }
1332    }
1333}
1334
1335/// W3.7 AOT helper-relation splitting for deep joins with buried skew.
1336pub mod helper_split_pass {
1337    use std::collections::{HashMap, HashSet};
1338
1339    use xlog_core::{RelId, ScalarType, Schema};
1340    use xlog_ir::rir::{HelperSplitSpec, KCliqueVariableOrder};
1341    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
1342    use xlog_stats::StatsManager;
1343
    // NOTE(review): presumably the ratio above which a join input counts as
    // heavily skewed when picking a helper-split candidate. The use site is
    // outside this view — confirm against `choose_candidate`.
    const HEAVY_SKEW_RATIO: f64 = 10.0;
1345
    /// Description of a helper relation introduced by the pass.
    ///
    /// One spec is returned per rewritten rule by [`run`] and
    /// [`run_kclique_specs`]. `name` and `rel_id` are the values handed back
    /// by the caller-supplied `allocate` callback.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct HelperRelationSpec {
        /// Predicate name allocated for the helper relation.
        ///
        /// Also used as the head of the inserted helper rule.
        pub name: String,
        /// Relation identifier allocated for the helper relation.
        pub rel_id: RelId,
        /// Output schema of the helper relation.
        ///
        /// This is the schema that was passed to `allocate`.
        pub schema: Schema,
        /// Pair of source relations extracted into the helper body.
        pub source_rels: [RelId; 2],
    }
1358
    // Key columns of one join step inside a `LinearBody`.
    // NOTE(review): presumably mirrors `RirNode::Join`'s `left_keys` /
    // `right_keys`; construction is outside this view — confirm.
    struct JoinStep {
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
    }
1363
    // Linearized view of a rule body, produced by `linearize_project_body`
    // and consumed by `choose_candidate`, `build_helper_body` and
    // `build_outer_body`.
    struct LinearBody {
        // Relations at the leaves of the body. `Candidate::pair_start`
        // indexes into this list.
        leaves: Vec<RelId>,
        // NOTE(review): presumably one list of variable/equivalence-class ids
        // per leaf — confirm.
        leaf_classes: Vec<Vec<u32>>,
        // Join steps between the leaves.
        joins: Vec<JoinStep>,
        // NOTE(review): presumably the body's top-level projection — confirm.
        project: Vec<ProjectExpr>,
        // NOTE(review): presumably the class ids of the projected output
        // columns — confirm.
        final_classes: Vec<u32>,
    }
1371
    // NOTE(review): not referenced in the visible part of this module. The
    // field notes below are inferred from names only — confirm against the
    // code that builds it.
    struct FlatJoin {
        // Presumably the relations joined, in order.
        leaves: Vec<RelId>,
        // Presumably the column indices kept in the output.
        output_cols: Vec<usize>,
        // Presumably pairs of column indices required to be equal.
        equalities: Vec<(usize, usize)>,
    }
1377
    // A helper-split choice for one `LinearBody`, produced by
    // `choose_candidate`.
    struct Candidate {
        // Index into `LinearBody::leaves` of the first relation of the pair;
        // the second is at `pair_start + 1`.
        pair_start: usize,
        // Schema passed to `allocate` and recorded in the resulting
        // `HelperRelationSpec`.
        helper_schema: Schema,
        // The remaining fields are consumed by `build_helper_body` /
        // `build_outer_body`, which are outside this view.
        // NOTE(review): meanings inferred from names — confirm.
        helper_project: Vec<ProjectExpr>,
        helper_join_left_keys: Vec<usize>,
        helper_join_right_keys: Vec<usize>,
        exposed_classes: Vec<u32>,
    }
1386
    // Result of successfully rewriting one rule.
    struct Rewrite {
        // Body of the new helper rule inserted before the original rule.
        helper_body: RirNode,
        // Replacement body for the original rule.
        outer_body: RirNode,
        // Description of the helper relation, returned to the pass's caller.
        spec: HelperRelationSpec,
    }
1392
    // One edge input of a K-clique `MultiWayJoin`, as seen by the helper
    // split.
    #[derive(Clone, Copy)]
    struct KCliqueHelperEdge {
        // Index of this edge in the `MultiWayJoin` inputs.
        slot: usize,
        // Relation scanned at that slot.
        rel: RelId,
        // Pair returned by `kclique_edge_pair(slot, k)`; each side is
        // compared against the spec's hot variable.
        left: usize,
        right: usize,
    }
1400
1401    /// Rewrite eligible rules in-place and return the helper relations introduced.
1402    pub fn run<F>(
1403        plan: &mut ExecutionPlan,
1404        schemas: &HashMap<RelId, Schema>,
1405        stats: &StatsManager,
1406        mut allocate: F,
1407    ) -> Vec<HelperRelationSpec>
1408    where
1409        F: FnMut(Schema) -> (String, RelId),
1410    {
1411        let mut specs = Vec::new();
1412        for scc_idx in 0..plan.rules_by_scc.len() {
1413            let mut rule_idx = 0;
1414            while rule_idx < plan.rules_by_scc[scc_idx].len() {
1415                let rewrite = {
1416                    let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1417                    try_rewrite_rule(rule, schemas, stats, &mut allocate)
1418                };
1419                if let Some(rewrite) = rewrite {
1420                    let helper_rule = CompiledRule {
1421                        head: rewrite.spec.name.clone(),
1422                        body: rewrite.helper_body,
1423                        meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1424                    };
1425                    plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1426                    rule_idx += 1;
1427                    plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1428                    add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1429                    specs.push(rewrite.spec);
1430                }
1431                rule_idx += 1;
1432            }
1433        }
1434        specs
1435    }
1436
1437    /// Authorization 5 G_HELP_KC entry for K-clique plans that
1438    /// already carry planner-produced `HelperSplitSpec`s. The pass
1439    /// reuses the Phase-1 G4 helper-relation lifecycle: emit a helper
1440    /// rule before the consumer rule, allocate a compiler-owned helper
1441    /// relation, and rewrite the consumer to scan that helper.
1442    pub fn run_kclique_specs<F>(
1443        plan: &mut ExecutionPlan,
1444        schemas: &HashMap<RelId, Schema>,
1445        mut allocate: F,
1446    ) -> Vec<HelperRelationSpec>
1447    where
1448        F: FnMut(Schema) -> (String, RelId),
1449    {
1450        let mut specs = Vec::new();
1451        for scc_idx in 0..plan.rules_by_scc.len() {
1452            let mut rule_idx = 0;
1453            while rule_idx < plan.rules_by_scc[scc_idx].len() {
1454                let rewrite = {
1455                    let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1456                    try_rewrite_kclique_rule(rule, schemas, &mut allocate)
1457                };
1458                if let Some(rewrite) = rewrite {
1459                    let helper_rule = CompiledRule {
1460                        head: rewrite.spec.name.clone(),
1461                        body: rewrite.helper_body,
1462                        meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1463                    };
1464                    plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1465                    rule_idx += 1;
1466                    plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1467                    add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1468                    specs.push(rewrite.spec);
1469                }
1470                rule_idx += 1;
1471            }
1472        }
1473        specs
1474    }
1475
1476    fn add_helper_to_scc(sccs: &mut [Scc], scc_idx: usize, helper: &str) {
1477        if let Some(scc) = sccs.get_mut(scc_idx) {
1478            if !scc.predicates.iter().any(|p| p == helper) {
1479                scc.predicates.push(helper.to_string());
1480            }
1481        }
1482    }
1483
1484    fn try_rewrite_rule<F>(
1485        rule: &CompiledRule,
1486        schemas: &HashMap<RelId, Schema>,
1487        stats: &StatsManager,
1488        allocate: &mut F,
1489    ) -> Option<Rewrite>
1490    where
1491        F: FnMut(Schema) -> (String, RelId),
1492    {
1493        let linear = linearize_project_body(&rule.body, schemas)?;
1494        let candidate = choose_candidate(&linear, schemas, stats)?;
1495        let (helper_name, helper_rel) = allocate(candidate.helper_schema.clone());
1496        let helper_body = build_helper_body(&linear, &candidate);
1497        let outer_body = build_outer_body(&linear, &candidate, helper_rel)?;
1498        Some(Rewrite {
1499            helper_body,
1500            outer_body,
1501            spec: HelperRelationSpec {
1502                name: helper_name,
1503                rel_id: helper_rel,
1504                schema: candidate.helper_schema,
1505                source_rels: [
1506                    linear.leaves[candidate.pair_start],
1507                    linear.leaves[candidate.pair_start + 1],
1508                ],
1509            },
1510        })
1511    }
1512
1513    fn try_rewrite_kclique_rule<F>(
1514        rule: &CompiledRule,
1515        schemas: &HashMap<RelId, Schema>,
1516        allocate: &mut F,
1517    ) -> Option<Rewrite>
1518    where
1519        F: FnMut(Schema) -> (String, RelId),
1520    {
1521        let mut outer_body = rule.body.clone();
1522        let RirNode::MultiWayJoin {
1523            inputs, var_order, ..
1524        } = &mut outer_body
1525        else {
1526            return None;
1527        };
1528        let kclique = var_order.as_ref()?.kclique.as_ref()?;
1529        let spec = kclique.helper_split_specs.first()?;
1530        let (hot_left, hot_right, target) = kclique_helper_edges(inputs, kclique, spec)?;
1531        let helper_schema = schemas.get(&target.rel)?.clone();
1532        let (helper_name, helper_rel) = allocate(helper_schema.clone());
1533        let helper_body = build_kclique_helper_body(spec, hot_left, hot_right, target)?;
1534        *inputs.get_mut(target.slot)? = RirNode::Scan { rel: helper_rel };
1535        Some(Rewrite {
1536            helper_body,
1537            outer_body,
1538            spec: HelperRelationSpec {
1539                name: helper_name,
1540                rel_id: helper_rel,
1541                schema: helper_schema,
1542                source_rels: [hot_left.rel, hot_right.rel],
1543            },
1544        })
1545    }
1546
1547    fn kclique_helper_edges(
1548        inputs: &[RirNode],
1549        kclique: &KCliqueVariableOrder,
1550        spec: &HelperSplitSpec,
1551    ) -> Option<(KCliqueHelperEdge, KCliqueHelperEdge, KCliqueHelperEdge)> {
1552        let k = usize::from(kclique.k);
1553        let hot = usize::from(spec.variable);
1554        let mut hot_edges = Vec::new();
1555        let mut target = None;
1556        for &slot in &spec.edge_slots {
1557            let slot = usize::from(slot);
1558            let (left, right) = kclique_edge_pair(slot, k)?;
1559            let RirNode::Scan { rel } = inputs.get(slot)? else {
1560                return None;
1561            };
1562            let edge = KCliqueHelperEdge {
1563                slot,
1564                rel: *rel,
1565                left,
1566                right,
1567            };
1568            if left == hot || right == hot {
1569                hot_edges.push(edge);
1570            } else {
1571                target = Some(edge);
1572            }
1573        }
1574        if hot_edges.len() != 2 {
1575            return None;
1576        }
1577        Some((hot_edges[0], hot_edges[1], target?))
1578    }
1579
1580    fn build_kclique_helper_body(
1581        spec: &HelperSplitSpec,
1582        hot_left: KCliqueHelperEdge,
1583        hot_right: KCliqueHelperEdge,
1584        target: KCliqueHelperEdge,
1585    ) -> Option<RirNode> {
1586        let hot = usize::from(spec.variable);
1587        let target_left = target.left;
1588        let target_right = target.right;
1589        let first_other = kclique_other_endpoint(hot_left, hot)?;
1590        let second_other = kclique_other_endpoint(hot_right, hot)?;
1591        if ![first_other, second_other].contains(&target_left)
1592            || ![first_other, second_other].contains(&target_right)
1593        {
1594            return None;
1595        }
1596
1597        let first_scan = RirNode::Scan { rel: hot_left.rel };
1598        let second_scan = RirNode::Scan { rel: hot_right.rel };
1599        let target_scan = RirNode::Scan { rel: target.rel };
1600        let first_hot_col = kclique_endpoint_col(hot_left, hot)?;
1601        let second_hot_col = kclique_endpoint_col(hot_right, hot)?;
1602        let first_other_col = kclique_endpoint_col(hot_left, first_other)?;
1603        let second_other_col = 2 + kclique_endpoint_col(hot_right, second_other)?;
1604
1605        let target_left_in_join = if first_other == target_left {
1606            first_other_col
1607        } else {
1608            second_other_col
1609        };
1610        let target_right_in_join = if first_other == target_right {
1611            first_other_col
1612        } else {
1613            second_other_col
1614        };
1615        let target_left_col = kclique_endpoint_col(target, target_left)?;
1616        let target_right_col = kclique_endpoint_col(target, target_right)?;
1617
1618        let hot_join = RirNode::Join {
1619            left: Box::new(first_scan),
1620            right: Box::new(second_scan),
1621            left_keys: vec![first_hot_col],
1622            right_keys: vec![second_hot_col],
1623            join_type: JoinType::Inner,
1624        };
1625        let helper_join = RirNode::Join {
1626            left: Box::new(hot_join),
1627            right: Box::new(target_scan),
1628            left_keys: vec![target_left_in_join, target_right_in_join],
1629            right_keys: vec![target_left_col, target_right_col],
1630            join_type: JoinType::Inner,
1631        };
1632        Some(RirNode::Project {
1633            input: Box::new(helper_join),
1634            columns: vec![ProjectExpr::Column(4), ProjectExpr::Column(5)],
1635        })
1636    }
1637
1638    fn kclique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
1639        let mut idx = 0usize;
1640        for left in 0..k {
1641            for right in (left + 1)..k {
1642                if idx == edge_idx {
1643                    return Some((left, right));
1644                }
1645                idx += 1;
1646            }
1647        }
1648        None
1649    }
1650
1651    fn kclique_endpoint_col(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1652        if edge.left == variable {
1653            Some(0)
1654        } else if edge.right == variable {
1655            Some(1)
1656        } else {
1657            None
1658        }
1659    }
1660
1661    fn kclique_other_endpoint(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1662        if edge.left == variable {
1663            Some(edge.right)
1664        } else if edge.right == variable {
1665            Some(edge.left)
1666        } else {
1667            None
1668        }
1669    }
1670
1671    fn linearize_project_body(
1672        body: &RirNode,
1673        schemas: &HashMap<RelId, Schema>,
1674    ) -> Option<LinearBody> {
1675        let RirNode::Project { input, columns } = body else {
1676            return None;
1677        };
1678        let flat = collect_join_graph(input, schemas)?;
1679        if flat.leaves.len() < 6 {
1680            return None;
1681        }
1682        let mut offsets = Vec::with_capacity(flat.leaves.len());
1683        let mut total_cols = 0usize;
1684        for rel in &flat.leaves {
1685            offsets.push(total_cols);
1686            total_cols += schemas.get(rel)?.arity();
1687        }
1688        let mut uf = UnionFind::new(total_cols);
1689        for (left, right) in flat.equalities {
1690            if left >= total_cols || right >= total_cols {
1691                return None;
1692            }
1693            uf.union(left, right);
1694        }
1695        let mut leaf_classes: Vec<Vec<u32>> = Vec::with_capacity(flat.leaves.len());
1696        for (leaf_idx, rel) in flat.leaves.iter().enumerate() {
1697            let arity = schemas.get(rel)?.arity();
1698            let offset = offsets[leaf_idx];
1699            leaf_classes.push((0..arity).map(|col| uf.find(offset + col) as u32).collect());
1700        }
1701        let final_classes = flat
1702            .output_cols
1703            .iter()
1704            .map(|col| uf.find(*col) as u32)
1705            .collect();
1706        let joins = derive_left_deep_steps(&leaf_classes)?;
1707        Some(LinearBody {
1708            leaves: flat.leaves,
1709            leaf_classes,
1710            joins,
1711            project: columns.clone(),
1712            final_classes,
1713        })
1714    }
1715
1716    fn collect_join_graph(node: &RirNode, schemas: &HashMap<RelId, Schema>) -> Option<FlatJoin> {
1717        match node {
1718            RirNode::Scan { rel } => Some(FlatJoin {
1719                leaves: vec![*rel],
1720                output_cols: (0..schemas.get(rel)?.arity()).collect(),
1721                equalities: Vec::new(),
1722            }),
1723            RirNode::Join {
1724                left,
1725                right,
1726                left_keys,
1727                right_keys,
1728                join_type,
1729            } if *join_type == JoinType::Inner => {
1730                let left_flat = collect_join_graph(left, schemas)?;
1731                let right_flat = collect_join_graph(right, schemas)?;
1732                if left_keys.len() != right_keys.len() {
1733                    return None;
1734                }
1735                let right_shift = total_width(&left_flat.leaves, schemas)?;
1736                let mut leaves = left_flat.leaves;
1737                leaves.extend(right_flat.leaves);
1738                let right_output_cols: Vec<usize> = right_flat
1739                    .output_cols
1740                    .iter()
1741                    .map(|col| col + right_shift)
1742                    .collect();
1743                let mut equalities = left_flat.equalities;
1744                equalities.extend(
1745                    right_flat
1746                        .equalities
1747                        .iter()
1748                        .map(|(left, right)| (left + right_shift, right + right_shift)),
1749                );
1750                for (&left_key, &right_key) in left_keys.iter().zip(right_keys.iter()) {
1751                    equalities.push((
1752                        *left_flat.output_cols.get(left_key)?,
1753                        *right_output_cols.get(right_key)?,
1754                    ));
1755                }
1756                let mut output_cols = left_flat.output_cols;
1757                output_cols.extend(right_output_cols);
1758                Some(FlatJoin {
1759                    leaves,
1760                    output_cols,
1761                    equalities,
1762                })
1763            }
1764            _ => None,
1765        }
1766    }
1767
1768    fn total_width(leaves: &[RelId], schemas: &HashMap<RelId, Schema>) -> Option<usize> {
1769        leaves
1770            .iter()
1771            .map(|rel| schemas.get(rel).map(Schema::arity))
1772            .try_fold(0usize, |acc, width| width.map(|width| acc + width))
1773    }
1774
1775    fn derive_left_deep_steps(leaf_classes: &[Vec<u32>]) -> Option<Vec<JoinStep>> {
1776        let mut joins = Vec::with_capacity(leaf_classes.len().saturating_sub(1));
1777        let mut current = leaf_classes.first()?.clone();
1778        for classes in leaf_classes.iter().skip(1) {
1779            let mut left_keys = Vec::new();
1780            let mut right_keys = Vec::new();
1781            for (right_col, class) in classes.iter().enumerate() {
1782                if let Some(left_col) = current
1783                    .iter()
1784                    .position(|current_class| current_class == class)
1785                {
1786                    left_keys.push(left_col);
1787                    right_keys.push(right_col);
1788                }
1789            }
1790            if left_keys.is_empty() {
1791                return None;
1792            }
1793            joins.push(JoinStep {
1794                left_keys,
1795                right_keys,
1796            });
1797            current.extend(classes.iter().copied());
1798        }
1799        Some(joins)
1800    }
1801
1802    fn choose_candidate(
1803        linear: &LinearBody,
1804        schemas: &HashMap<RelId, Schema>,
1805        stats: &StatsManager,
1806    ) -> Option<Candidate> {
1807        for pair_start in 3..linear.leaves.len().saturating_sub(1) {
1808            let candidate = build_candidate(linear, schemas, pair_start)?;
1809            if skew_ratio_for_candidate(linear, stats, &candidate) >= HEAVY_SKEW_RATIO {
1810                return Some(candidate);
1811            }
1812        }
1813        None
1814    }
1815
1816    fn build_candidate(
1817        linear: &LinearBody,
1818        schemas: &HashMap<RelId, Schema>,
1819        pair_start: usize,
1820    ) -> Option<Candidate> {
1821        let left_rel = linear.leaves[pair_start];
1822        let right_rel = linear.leaves[pair_start + 1];
1823        let left_schema = schemas.get(&left_rel)?;
1824        let right_schema = schemas.get(&right_rel)?;
1825        let internal_step = linear.joins.get(pair_start)?;
1826        let mut helper_left_keys = Vec::new();
1827        let mut helper_right_keys = Vec::new();
1828        for (&left_key, &right_key) in internal_step
1829            .left_keys
1830            .iter()
1831            .zip(internal_step.right_keys.iter())
1832        {
1833            let class = class_at_state(linear, pair_start + 1, left_key)?;
1834            let left_col = linear.leaf_classes[pair_start]
1835                .iter()
1836                .position(|c| *c == class)?;
1837            helper_left_keys.push(left_col);
1838            helper_right_keys.push(right_key);
1839        }
1840        let internal: HashSet<u32> = helper_left_keys
1841            .iter()
1842            .map(|col| linear.leaf_classes[pair_start][*col])
1843            .collect();
1844        let outside = outside_classes(linear, pair_start);
1845        let output = projected_classes(linear)?;
1846        let mut exposed_classes = Vec::new();
1847        let mut helper_project = Vec::new();
1848        let mut helper_columns = Vec::new();
1849        for (col, class) in linear.leaf_classes[pair_start].iter().copied().enumerate() {
1850            if !internal.contains(&class)
1851                && (outside.contains(&class) || output.contains(&class))
1852                && !exposed_classes.contains(&class)
1853            {
1854                exposed_classes.push(class);
1855                helper_project.push(ProjectExpr::Column(col));
1856                let ty = left_schema.column_type(col).unwrap_or(ScalarType::U32);
1857                helper_columns.push((format!("c{}", helper_columns.len()), ty));
1858            }
1859        }
1860        let right_offset = left_schema.arity();
1861        for (col, class) in linear.leaf_classes[pair_start + 1]
1862            .iter()
1863            .copied()
1864            .enumerate()
1865        {
1866            if !internal.contains(&class)
1867                && (outside.contains(&class) || output.contains(&class))
1868                && !exposed_classes.contains(&class)
1869            {
1870                exposed_classes.push(class);
1871                helper_project.push(ProjectExpr::Column(right_offset + col));
1872                let ty = right_schema.column_type(col).unwrap_or(ScalarType::U32);
1873                helper_columns.push((format!("c{}", helper_columns.len()), ty));
1874            }
1875        }
1876        if exposed_classes.len() != 2 {
1877            return None;
1878        }
1879        Some(Candidate {
1880            pair_start,
1881            helper_schema: Schema::new(helper_columns),
1882            helper_project,
1883            helper_join_left_keys: helper_left_keys,
1884            helper_join_right_keys: helper_right_keys,
1885            exposed_classes,
1886        })
1887    }
1888
1889    fn class_at_state(linear: &LinearBody, leaf_count: usize, col: usize) -> Option<u32> {
1890        let mut idx = col;
1891        for leaf_idx in 0..leaf_count {
1892            let classes = &linear.leaf_classes[leaf_idx];
1893            if idx < classes.len() {
1894                return Some(classes[idx]);
1895            }
1896            idx -= classes.len();
1897        }
1898        None
1899    }
1900
1901    fn outside_classes(linear: &LinearBody, pair_start: usize) -> HashSet<u32> {
1902        linear
1903            .leaf_classes
1904            .iter()
1905            .enumerate()
1906            .filter(|(idx, _)| *idx != pair_start && *idx != pair_start + 1)
1907            .flat_map(|(_, classes)| classes.iter().copied())
1908            .collect()
1909    }
1910
1911    fn projected_classes(linear: &LinearBody) -> Option<HashSet<u32>> {
1912        let mut out = HashSet::new();
1913        for expr in &linear.project {
1914            let ProjectExpr::Column(col) = expr else {
1915                return None;
1916            };
1917            out.insert(*linear.final_classes.get(*col)?);
1918        }
1919        Some(out)
1920    }
1921
1922    fn skew_ratio_for_candidate(
1923        linear: &LinearBody,
1924        stats: &StatsManager,
1925        candidate: &Candidate,
1926    ) -> f64 {
1927        let rel = linear.leaves[candidate.pair_start];
1928        let Some(rel_stats) = stats.get_relation_stats(rel) else {
1929            return 0.0;
1930        };
1931        let mut ratio: f64 = 0.0;
1932        for (col, class) in linear.leaf_classes[candidate.pair_start]
1933            .iter()
1934            .copied()
1935            .enumerate()
1936        {
1937            if !candidate.exposed_classes.contains(&class) {
1938                continue;
1939            }
1940            let Some(col_stats) = rel_stats.get_column(col) else {
1941                continue;
1942            };
1943            if col_stats.distinct_estimate == 0 {
1944                continue;
1945            }
1946            ratio = ratio.max(rel_stats.cardinality as f64 / col_stats.distinct_estimate as f64);
1947        }
1948        ratio
1949    }
1950
1951    fn build_helper_body(linear: &LinearBody, candidate: &Candidate) -> RirNode {
1952        let left = RirNode::Scan {
1953            rel: linear.leaves[candidate.pair_start],
1954        };
1955        let right = RirNode::Scan {
1956            rel: linear.leaves[candidate.pair_start + 1],
1957        };
1958        RirNode::Project {
1959            input: Box::new(RirNode::Join {
1960                left: Box::new(left),
1961                right: Box::new(right),
1962                left_keys: candidate.helper_join_left_keys.clone(),
1963                right_keys: candidate.helper_join_right_keys.clone(),
1964                join_type: JoinType::Inner,
1965            }),
1966            columns: candidate.helper_project.clone(),
1967        }
1968    }
1969
1970    fn build_outer_body(
1971        linear: &LinearBody,
1972        candidate: &Candidate,
1973        helper_rel: RelId,
1974    ) -> Option<RirNode> {
1975        let mut node = RirNode::Scan {
1976            rel: linear.leaves[0],
1977        };
1978        let mut classes = linear.leaf_classes[0].clone();
1979        for leaf_idx in 1..candidate.pair_start {
1980            let step = &linear.joins[leaf_idx - 1];
1981            node = RirNode::Join {
1982                left: Box::new(node),
1983                right: Box::new(RirNode::Scan {
1984                    rel: linear.leaves[leaf_idx],
1985                }),
1986                left_keys: step.left_keys.clone(),
1987                right_keys: step.right_keys.clone(),
1988                join_type: JoinType::Inner,
1989            };
1990            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
1991        }
1992        let prefix_step = &linear.joins[candidate.pair_start - 1];
1993        let mut helper_right_keys = Vec::new();
1994        for &rk in &prefix_step.right_keys {
1995            let class = linear.leaf_classes[candidate.pair_start][rk];
1996            helper_right_keys.push(candidate.exposed_classes.iter().position(|c| *c == class)?);
1997        }
1998        node = RirNode::Join {
1999            left: Box::new(node),
2000            right: Box::new(RirNode::Scan { rel: helper_rel }),
2001            left_keys: prefix_step.left_keys.clone(),
2002            right_keys: helper_right_keys,
2003            join_type: JoinType::Inner,
2004        };
2005        classes.extend(candidate.exposed_classes.iter().copied());
2006        for leaf_idx in candidate.pair_start + 2..linear.leaves.len() {
2007            let step = &linear.joins[leaf_idx - 1];
2008            let mut left_keys = Vec::new();
2009            for &lk in &step.left_keys {
2010                let class = class_at_state(linear, leaf_idx, lk)?;
2011                left_keys.push(classes.iter().position(|c| *c == class)?);
2012            }
2013            node = RirNode::Join {
2014                left: Box::new(node),
2015                right: Box::new(RirNode::Scan {
2016                    rel: linear.leaves[leaf_idx],
2017                }),
2018                left_keys,
2019                right_keys: step.right_keys.clone(),
2020                join_type: JoinType::Inner,
2021            };
2022            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
2023        }
2024        let mut project = Vec::with_capacity(linear.project.len());
2025        for expr in &linear.project {
2026            let ProjectExpr::Column(col) = expr else {
2027                return None;
2028            };
2029            let class = *linear.final_classes.get(*col)?;
2030            let mapped = classes.iter().position(|c| *c == class)?;
2031            project.push(ProjectExpr::Column(mapped));
2032        }
2033        Some(RirNode::Project {
2034            input: Box::new(node),
2035            columns: project,
2036        })
2037    }
2038
2039    struct UnionFind {
2040        parent: Vec<usize>,
2041    }
2042
2043    impl UnionFind {
2044        fn new(len: usize) -> Self {
2045            Self {
2046                parent: (0..len).collect(),
2047            }
2048        }
2049
2050        fn find(&mut self, x: usize) -> usize {
2051            let p = self.parent[x];
2052            if p == x {
2053                x
2054            } else {
2055                let root = self.find(p);
2056                self.parent[x] = root;
2057                root
2058            }
2059        }
2060
2061        fn union(&mut self, a: usize, b: usize) {
2062            let ra = self.find(a);
2063            let rb = self.find(b);
2064            if ra != rb {
2065                self.parent[rb] = ra;
2066            }
2067        }
2068    }
2069}
2070
2071#[path = "optimizer/stream_schedule_pass.rs"]
2072pub mod stream_schedule_pass;
2073
#[cfg(test)]
mod helper_split_pass_tests {
    use std::collections::HashMap;

    use super::helper_split_pass;
    use xlog_core::{RelId, ScalarType, Schema};
    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
    use xlog_stats::{ColumnStats, StatsManager};

    /// Name handed out by the fixed test allocator.
    const HELPER_NAME: &str = "__w37_helper_6";

    /// Schema whose columns all have type `U32` and carry the given names.
    fn u32_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|name| (name.to_string(), ScalarType::U32))
                .collect::<Vec<_>>(),
        )
    }

    /// Schema shared by all six base relations of the fixture.
    fn edge_schema() -> Schema {
        u32_schema(&["c0", "c1"])
    }

    /// Schema the pass is expected to give the helper relation.
    fn helper_schema() -> Schema {
        u32_schema(&["c0", "c1"])
    }

    /// Schemas for relations 0 through 5.
    fn schemas() -> HashMap<RelId, Schema> {
        let mut by_rel = HashMap::new();
        for idx in 0..6 {
            by_rel.insert(RelId(idx), edge_schema());
        }
        by_rel
    }

    /// Inner join of `left` with a scan of `rel`.
    fn join_scan(
        left: RirNode,
        rel: RelId,
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
    ) -> RirNode {
        RirNode::Join {
            left: Box::new(left),
            right: Box::new(RirNode::Scan { rel }),
            left_keys,
            right_keys,
            join_type: JoinType::Inner,
        }
    }

    /// Left-deep chain over relations 0..=5 under a five-column projection.
    fn left_deep_fixture_body() -> RirNode {
        // (relation joined next, left keys, right keys), outermost join last.
        let chain = [
            (RelId(1), vec![1], vec![0]),
            (RelId(2), vec![3], vec![0]),
            (RelId(3), vec![5], vec![0]),
            (RelId(4), vec![7], vec![0]),
            (RelId(5), vec![0, 9], vec![0, 1]),
        ];
        let mut joined = RirNode::Scan { rel: RelId(0) };
        for (rel, left_keys, right_keys) in chain {
            joined = join_scan(joined, rel, left_keys, right_keys);
        }
        RirNode::Project {
            input: Box::new(joined),
            columns: vec![0, 1, 3, 5, 9]
                .into_iter()
                .map(ProjectExpr::Column)
                .collect(),
        }
    }

    /// Single-SCC plan holding one rule `out` with the fixture body.
    fn plan() -> ExecutionPlan {
        let rule = CompiledRule {
            head: "out".to_string(),
            body: left_deep_fixture_body(),
            meta: RirMeta::with_schema(u32_schema(&["a", "b", "c", "d", "f"])),
        };
        let scc = Scc {
            id: 0,
            predicates: vec!["out".to_string()],
            is_recursive: false,
        };
        ExecutionPlan {
            sccs: vec![scc],
            strata: vec![],
            rules_by_scc: vec![vec![rule]],
            est_memory_peak: 0,
        }
    }

    /// Statistics with cardinality 8192 for every relation and
    /// `distinct_d` distinct values in column 0 of relation 3.
    fn stats_for_de(distinct_d: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for idx in 0..6 {
            let rel = RelId(idx);
            stats.register_relation(rel);
            stats.update_cardinality(rel, 8192);
        }
        let mut d_col = ColumnStats::new(0, ScalarType::U32);
        d_col.update_distinct(distinct_d);
        stats.add_column_stats(RelId(3), d_col);
        stats
    }

    /// Allocator that always returns the same helper name and relation id.
    fn allocate_fixed_helper(_schema: Schema) -> (String, RelId) {
        (HELPER_NAME.to_string(), RelId(6))
    }

    /// Whether `rel` is read anywhere inside `node`.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Unit => false,
            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
        }
    }

    #[test]
    fn helper_split_extracts_buried_pair() {
        let mut rewritten = plan();
        let schemas = schemas();
        let stats = stats_for_de(1);
        let specs =
            helper_split_pass::run(&mut rewritten, &schemas, &stats, allocate_fixed_helper);

        assert_eq!(specs.len(), 1);
        let spec = &specs[0];
        assert_eq!(spec.name, "__w37_helper_6");
        assert_eq!(spec.rel_id, RelId(6));
        assert_eq!(spec.schema, helper_schema());
        assert_eq!(spec.source_rels, [RelId(3), RelId(4)]);

        let rules = &rewritten.rules_by_scc[0];
        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].head, "__w37_helper_6");
        assert_eq!(rules[1].head, "out");
        assert!(contains_scan(&rules[1].body, RelId(6)));
        assert!(rewritten.sccs[0]
            .predicates
            .iter()
            .any(|predicate| predicate == "__w37_helper_6"));
    }

    #[test]
    fn helper_split_ignores_flat_distribution() {
        let mut untouched = plan();
        let schemas = schemas();
        let stats = stats_for_de(8192);
        let specs =
            helper_split_pass::run(&mut untouched, &schemas, &stats, allocate_fixed_helper);

        assert!(specs.is_empty());
        let rules = &untouched.rules_by_scc[0];
        assert_eq!(rules.len(), 1);
        assert!(!contains_scan(&rules[0].body, RelId(6)));
    }
}
2249
2250/// W2.2 — selectivity-driven body rewriters for triangle and
2251/// 4-cycle canonical lowered shapes. `pub(super)` so
2252/// `selectivity_pass::run` can dispatch into them.
2253mod reorder {
2254    use std::collections::HashMap;
2255    use xlog_core::RelId;
2256    use xlog_ir::rir::ProjectExpr;
2257    use xlog_ir::{JoinType, RirNode};
2258    use xlog_stats::StatsManager;
2259
2260    fn ac3(atom: u8, col: u8) -> u8 {
2261        atom * 2 + col
2262    }
2263    fn ac4(atom: u8, col: u8) -> u8 {
2264        atom * 2 + col
2265    }
2266    fn uf_find_n<const N: usize>(parent: &mut [u8; N], x: u8) -> u8 {
2267        let mut root = x;
2268        while parent[root as usize] != root {
2269            root = parent[root as usize];
2270        }
2271        let mut cur = x;
2272        while parent[cur as usize] != root {
2273            let next = parent[cur as usize];
2274            parent[cur as usize] = root;
2275            cur = next;
2276        }
2277        root
2278    }
2279    fn uf_union_n<const N: usize>(parent: &mut [u8; N], a: u8, b: u8) {
2280        let ra = uf_find_n(parent, a);
2281        let rb = uf_find_n(parent, b);
2282        if ra != rb {
2283            parent[rb as usize] = ra;
2284        }
2285    }
2286
2287    fn populated_card(stats: &StatsManager, rel: RelId) -> Option<u64> {
2288        stats
2289            .get_relation_stats(rel)
2290            .map(|s| s.cardinality)
2291            .filter(|c| *c > 0)
2292    }
2293
2294    // ---------------------------------------------------------
2295    // Triangle rewriter
2296    // ---------------------------------------------------------
2297
    /// Relations assigned to the three atoms of a matched triangle body, as
    /// returned by `match_and_infer_triangle`.
    ///
    /// NOTE(review): going by the field names, each relation joins the pair
    /// of triangle variables in its suffix (x–y, y–z, x–z) — the code that
    /// fills these fields is not visible here, confirm.
    struct TriangleSemantics {
        rel_xy: RelId,
        rel_yz: RelId,
        rel_xz: RelId,
    }
2303
2304    fn match_and_infer_triangle(body: &RirNode) -> Option<TriangleSemantics> {
2305        let RirNode::Project {
2306            input: outer_input,
2307            columns,
2308        } = body
2309        else {
2310            return None;
2311        };
2312        let RirNode::Join {
2313            left: l1,
2314            right: r1,
2315            left_keys: lk1,
2316            right_keys: rk1,
2317            join_type: jt1,
2318        } = outer_input.as_ref()
2319        else {
2320            return None;
2321        };
2322        if !matches!(jt1, JoinType::Inner) {
2323            return None;
2324        }
2325        let RirNode::Scan { rel: rel_third } = r1.as_ref() else {
2326            return None;
2327        };
2328        let RirNode::Join {
2329            left: l2,
2330            right: r2,
2331            left_keys: lk2,
2332            right_keys: rk2,
2333            join_type: jt2,
2334        } = l1.as_ref()
2335        else {
2336            return None;
2337        };
2338        if !matches!(jt2, JoinType::Inner) {
2339            return None;
2340        }
2341        let RirNode::Scan { rel: rel_inner_l } = l2.as_ref() else {
2342            return None;
2343        };
2344        let RirNode::Scan { rel: rel_inner_r } = r2.as_ref() else {
2345            return None;
2346        };
2347        if lk2.len() != 1 || rk2.len() != 1 || lk1.len() != 2 || rk1.len() != 2 {
2348            return None;
2349        }
2350        if columns.len() != 3 {
2351            return None;
2352        }
2353        if lk2[0] >= 2 || rk2[0] >= 2 {
2354            return None;
2355        }
2356        if lk1.iter().any(|k| *k >= 4) || rk1.iter().any(|k| *k >= 2) {
2357            return None;
2358        }
2359
2360        let mut parent = [0u8, 1, 2, 3, 4, 5];
2361        uf_union_n::<6>(&mut parent, ac3(0, lk2[0] as u8), ac3(1, rk2[0] as u8));
2362        for i in 0..2 {
2363            let inner_ac = match lk1[i] {
2364                0 => (0u8, 0u8),
2365                1 => (0, 1),
2366                2 => (1, 0),
2367                3 => (1, 1),
2368                _ => return None,
2369            };
2370            uf_union_n::<6>(
2371                &mut parent,
2372                ac3(inner_ac.0, inner_ac.1),
2373                ac3(2, rk1[i] as u8),
2374            );
2375        }
2376        let roots: [u8; 6] = std::array::from_fn(|i| uf_find_n::<6>(&mut parent, i as u8));
2377        let mut counts: HashMap<u8, u8> = HashMap::new();
2378        for r in &roots {
2379            *counts.entry(*r).or_insert(0) += 1;
2380        }
2381        if counts.len() != 3 || counts.values().any(|c| *c != 2) {
2382            return None;
2383        }
2384        let mut head_classes: [u8; 3] = [0; 3];
2385        for (i, pc) in columns.iter().enumerate() {
2386            let ProjectExpr::Column(k) = pc else {
2387                return None;
2388            };
2389            let outer_ac = match *k {
2390                0 => (0u8, 0u8),
2391                1 => (0, 1),
2392                2 => (1, 0),
2393                3 => (1, 1),
2394                4 => (2, 0),
2395                5 => (2, 1),
2396                _ => return None,
2397            };
2398            head_classes[i] = uf_find_n::<6>(&mut parent, ac3(outer_ac.0, outer_ac.1));
2399        }
2400        if head_classes[0] == head_classes[1]
2401            || head_classes[0] == head_classes[2]
2402            || head_classes[1] == head_classes[2]
2403        {
2404            return None;
2405        }
2406        let x_class = head_classes[0];
2407        let y_class = head_classes[1];
2408        let z_class = head_classes[2];
2409        let atom_classes = |a: u8| (roots[ac3(a, 0) as usize], roots[ac3(a, 1) as usize]);
2410        let atom_rels = [*rel_inner_l, *rel_inner_r, *rel_third];
2411        let mut rel_xy = None;
2412        let mut rel_yz = None;
2413        let mut rel_xz = None;
2414        for atom_idx in 0..3u8 {
2415            let (c0, c1) = atom_classes(atom_idx);
2416            let bx = c0 == x_class || c1 == x_class;
2417            let by = c0 == y_class || c1 == y_class;
2418            let bz = c0 == z_class || c1 == z_class;
2419            match (bx, by, bz) {
2420                (true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2421                (false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2422                (true, false, true) => rel_xz = Some(atom_rels[atom_idx as usize]),
2423                _ => return None,
2424            }
2425        }
2426        Some(TriangleSemantics {
2427            rel_xy: rel_xy?,
2428            rel_yz: rel_yz?,
2429            rel_xz: rel_xz?,
2430        })
2431    }
2432
    /// Which two triangle relations are joined first (the inner join) when
    /// `build_triangle_body` assembles a body; the variant names the
    /// variable the inner join is keyed on.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    #[allow(clippy::enum_variant_names)]
    enum TriangleInnerPair {
        /// Inner join is `rel_xy ⋈ rel_yz`; `rel_xz` is joined last.
        YShared,
        /// Inner join is `rel_xy ⋈ rel_xz`; `rel_yz` is joined last.
        XShared,
        /// Inner join is `rel_xz ⋈ rel_yz`; `rel_xy` is joined last.
        ZShared,
    }
2440
2441    fn build_triangle_body(s: &TriangleSemantics, inner_pair: TriangleInnerPair) -> RirNode {
2442        let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2443        match inner_pair {
2444            TriangleInnerPair::YShared => {
2445                let inner = RirNode::Join {
2446                    left: Box::new(mk_scan(s.rel_xy)),
2447                    right: Box::new(mk_scan(s.rel_yz)),
2448                    left_keys: vec![1],
2449                    right_keys: vec![0],
2450                    join_type: JoinType::Inner,
2451                };
2452                let outer = RirNode::Join {
2453                    left: Box::new(inner),
2454                    right: Box::new(mk_scan(s.rel_xz)),
2455                    left_keys: vec![0, 3],
2456                    right_keys: vec![0, 1],
2457                    join_type: JoinType::Inner,
2458                };
2459                RirNode::Project {
2460                    input: Box::new(outer),
2461                    columns: vec![
2462                        ProjectExpr::Column(0),
2463                        ProjectExpr::Column(1),
2464                        ProjectExpr::Column(3),
2465                    ],
2466                }
2467            }
2468            TriangleInnerPair::XShared => {
2469                let inner = RirNode::Join {
2470                    left: Box::new(mk_scan(s.rel_xy)),
2471                    right: Box::new(mk_scan(s.rel_xz)),
2472                    left_keys: vec![0],
2473                    right_keys: vec![0],
2474                    join_type: JoinType::Inner,
2475                };
2476                let outer = RirNode::Join {
2477                    left: Box::new(inner),
2478                    right: Box::new(mk_scan(s.rel_yz)),
2479                    left_keys: vec![1, 3],
2480                    right_keys: vec![0, 1],
2481                    join_type: JoinType::Inner,
2482                };
2483                RirNode::Project {
2484                    input: Box::new(outer),
2485                    columns: vec![
2486                        ProjectExpr::Column(0),
2487                        ProjectExpr::Column(1),
2488                        ProjectExpr::Column(3),
2489                    ],
2490                }
2491            }
2492            TriangleInnerPair::ZShared => {
2493                let inner = RirNode::Join {
2494                    left: Box::new(mk_scan(s.rel_xz)),
2495                    right: Box::new(mk_scan(s.rel_yz)),
2496                    left_keys: vec![1],
2497                    right_keys: vec![1],
2498                    join_type: JoinType::Inner,
2499                };
2500                let outer = RirNode::Join {
2501                    left: Box::new(inner),
2502                    right: Box::new(mk_scan(s.rel_xy)),
2503                    left_keys: vec![0, 2],
2504                    right_keys: vec![0, 1],
2505                    join_type: JoinType::Inner,
2506                };
2507                RirNode::Project {
2508                    input: Box::new(outer),
2509                    columns: vec![
2510                        ProjectExpr::Column(0),
2511                        ProjectExpr::Column(2),
2512                        ProjectExpr::Column(3),
2513                    ],
2514                }
2515            }
2516        }
2517    }
2518
2519    pub fn try_reorder_triangle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2520        let s = match_and_infer_triangle(body)?;
2521        let _ = (
2522            populated_card(stats, s.rel_xy)?,
2523            populated_card(stats, s.rel_yz)?,
2524            populated_card(stats, s.rel_xz)?,
2525        );
2526        let est_y = stats.estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0]);
2527        let est_x = stats.estimate_join_cardinality(s.rel_xy, s.rel_xz, &[0], &[0]);
2528        let est_z = stats.estimate_join_cardinality(s.rel_yz, s.rel_xz, &[1], &[1]);
2529        let mut best = (TriangleInnerPair::YShared, est_y);
2530        if est_x < best.1 {
2531            best = (TriangleInnerPair::XShared, est_x);
2532        }
2533        if est_z < best.1 {
2534            best = (TriangleInnerPair::ZShared, est_z);
2535        }
2536        let candidate = build_triangle_body(&s, best.0);
2537        // Skip when the candidate is structurally identical to
2538        // the input (no-op rewrite). RirNode doesn't impl
2539        // PartialEq, so compare via Debug — bodies are small
2540        // (≤ 6 Scans + 2 Joins + 1 Project) so the cost is
2541        // negligible relative to the optimizer's broader work.
2542        if format!("{:?}", candidate) == format!("{:?}", body) {
2543            return None;
2544        }
2545        Some(candidate)
2546    }
2547
2548    // ---------------------------------------------------------
2549    // 4-cycle rewriter
2550    // ---------------------------------------------------------
2551
    /// Roles of the four scanned relations in a matched 4-cycle body.
    ///
    /// `W`, `X`, `Y` and `Z` are the variable classes named by the first
    /// through fourth projected head column, as assigned by
    /// `match_and_infer_4cycle`.
    struct Cycle4Semantics {
        /// Relation whose columns carry the `W` and `X` classes.
        rel_wx: RelId,
        /// Relation whose columns carry the `X` and `Y` classes.
        rel_xy: RelId,
        /// Relation whose columns carry the `Y` and `Z` classes.
        rel_yz: RelId,
        /// Relation whose columns carry the `Z` and `W` classes.
        rel_zw: RelId,
    }
2558
2559    fn match_and_infer_4cycle(body: &RirNode) -> Option<Cycle4Semantics> {
2560        let RirNode::Project {
2561            input: outer_input,
2562            columns,
2563        } = body
2564        else {
2565            return None;
2566        };
2567        let RirNode::Join {
2568            left: outer_l,
2569            right: outer_r,
2570            left_keys: olk,
2571            right_keys: ork,
2572            join_type: ojt,
2573        } = outer_input.as_ref()
2574        else {
2575            return None;
2576        };
2577        if !matches!(ojt, JoinType::Inner) {
2578            return None;
2579        }
2580        let RirNode::Join {
2581            left: ll,
2582            right: lr,
2583            left_keys: ilk_l,
2584            right_keys: irk_l,
2585            join_type: ijt_l,
2586        } = outer_l.as_ref()
2587        else {
2588            return None;
2589        };
2590        if !matches!(ijt_l, JoinType::Inner) {
2591            return None;
2592        }
2593        let RirNode::Scan { rel: rel_ll } = ll.as_ref() else {
2594            return None;
2595        };
2596        let RirNode::Scan { rel: rel_lr } = lr.as_ref() else {
2597            return None;
2598        };
2599        let RirNode::Join {
2600            left: rl,
2601            right: rr,
2602            left_keys: ilk_r,
2603            right_keys: irk_r,
2604            join_type: ijt_r,
2605        } = outer_r.as_ref()
2606        else {
2607            return None;
2608        };
2609        if !matches!(ijt_r, JoinType::Inner) {
2610            return None;
2611        }
2612        let RirNode::Scan { rel: rel_rl } = rl.as_ref() else {
2613            return None;
2614        };
2615        let RirNode::Scan { rel: rel_rr } = rr.as_ref() else {
2616            return None;
2617        };
2618        if ilk_l.len() != 1 || irk_l.len() != 1 || ilk_r.len() != 1 || irk_r.len() != 1 {
2619            return None;
2620        }
2621        if olk.len() != 2 || ork.len() != 2 || columns.len() != 4 {
2622            return None;
2623        }
2624        if ilk_l[0] >= 2 || irk_l[0] >= 2 || ilk_r[0] >= 2 || irk_r[0] >= 2 {
2625            return None;
2626        }
2627        if olk.iter().any(|k| *k >= 4) || ork.iter().any(|k| *k >= 4) {
2628            return None;
2629        }
2630
2631        let mut parent = [0u8, 1, 2, 3, 4, 5, 6, 7];
2632        uf_union_n::<8>(&mut parent, ac4(0, ilk_l[0] as u8), ac4(1, irk_l[0] as u8));
2633        uf_union_n::<8>(&mut parent, ac4(2, ilk_r[0] as u8), ac4(3, irk_r[0] as u8));
2634        for i in 0..2 {
2635            let l_ac = match olk[i] {
2636                0 => (0u8, 0u8),
2637                1 => (0, 1),
2638                2 => (1, 0),
2639                3 => (1, 1),
2640                _ => return None,
2641            };
2642            let r_ac = match ork[i] {
2643                0 => (2u8, 0u8),
2644                1 => (2, 1),
2645                2 => (3, 0),
2646                3 => (3, 1),
2647                _ => return None,
2648            };
2649            uf_union_n::<8>(&mut parent, ac4(l_ac.0, l_ac.1), ac4(r_ac.0, r_ac.1));
2650        }
2651        let roots: [u8; 8] = std::array::from_fn(|i| uf_find_n::<8>(&mut parent, i as u8));
2652        let mut counts: HashMap<u8, u8> = HashMap::new();
2653        for r in &roots {
2654            *counts.entry(*r).or_insert(0) += 1;
2655        }
2656        if counts.len() != 4 || counts.values().any(|c| *c != 2) {
2657            return None;
2658        }
2659
2660        let mut head_classes: [u8; 4] = [0; 4];
2661        for (i, pc) in columns.iter().enumerate() {
2662            let ProjectExpr::Column(k) = pc else {
2663                return None;
2664            };
2665            let ac = match *k {
2666                0 => (0u8, 0u8),
2667                1 => (0, 1),
2668                2 => (1, 0),
2669                3 => (1, 1),
2670                4 => (2, 0),
2671                5 => (2, 1),
2672                6 => (3, 0),
2673                7 => (3, 1),
2674                _ => return None,
2675            };
2676            head_classes[i] = uf_find_n::<8>(&mut parent, ac4(ac.0, ac.1));
2677        }
2678        for i in 0..4 {
2679            for j in (i + 1)..4 {
2680                if head_classes[i] == head_classes[j] {
2681                    return None;
2682                }
2683            }
2684        }
2685        let w_class = head_classes[0];
2686        let x_class = head_classes[1];
2687        let y_class = head_classes[2];
2688        let z_class = head_classes[3];
2689        let atom_classes = |a: u8| (roots[ac4(a, 0) as usize], roots[ac4(a, 1) as usize]);
2690        let atom_rels = [*rel_ll, *rel_lr, *rel_rl, *rel_rr];
2691        let mut rel_wx = None;
2692        let mut rel_xy = None;
2693        let mut rel_yz = None;
2694        let mut rel_zw = None;
2695        for atom_idx in 0..4u8 {
2696            let (c0, c1) = atom_classes(atom_idx);
2697            let bw = c0 == w_class || c1 == w_class;
2698            let bx = c0 == x_class || c1 == x_class;
2699            let by = c0 == y_class || c1 == y_class;
2700            let bz = c0 == z_class || c1 == z_class;
2701            match (bw, bx, by, bz) {
2702                (true, true, false, false) => rel_wx = Some(atom_rels[atom_idx as usize]),
2703                (false, true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2704                (false, false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2705                (true, false, false, true) => rel_zw = Some(atom_rels[atom_idx as usize]),
2706                _ => return None,
2707            }
2708        }
2709        Some(Cycle4Semantics {
2710            rel_wx: rel_wx?,
2711            rel_xy: rel_xy?,
2712            rel_yz: rel_yz?,
2713            rel_zw: rel_zw?,
2714        })
2715    }
2716
    /// Which adjacent relations are paired into the two inner joins when
    /// `build_4cycle_body` assembles a bushy 4-cycle body.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Cycle4Grouping {
        /// Inner joins are `rel_wx ⋈ rel_xy` and `rel_yz ⋈ rel_zw`.
        Default,
        /// Inner joins are `rel_xy ⋈ rel_yz` and `rel_zw ⋈ rel_wx`.
        Alt,
    }
2722
2723    fn build_4cycle_body(s: &Cycle4Semantics, g: Cycle4Grouping) -> RirNode {
2724        let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2725        match g {
2726            Cycle4Grouping::Default => {
2727                let il = RirNode::Join {
2728                    left: Box::new(mk_scan(s.rel_wx)),
2729                    right: Box::new(mk_scan(s.rel_xy)),
2730                    left_keys: vec![1],
2731                    right_keys: vec![0],
2732                    join_type: JoinType::Inner,
2733                };
2734                let ir = RirNode::Join {
2735                    left: Box::new(mk_scan(s.rel_yz)),
2736                    right: Box::new(mk_scan(s.rel_zw)),
2737                    left_keys: vec![1],
2738                    right_keys: vec![0],
2739                    join_type: JoinType::Inner,
2740                };
2741                let outer = RirNode::Join {
2742                    left: Box::new(il),
2743                    right: Box::new(ir),
2744                    left_keys: vec![0, 3],
2745                    right_keys: vec![3, 0],
2746                    join_type: JoinType::Inner,
2747                };
2748                RirNode::Project {
2749                    input: Box::new(outer),
2750                    columns: vec![
2751                        ProjectExpr::Column(0),
2752                        ProjectExpr::Column(1),
2753                        ProjectExpr::Column(3),
2754                        ProjectExpr::Column(5),
2755                    ],
2756                }
2757            }
2758            Cycle4Grouping::Alt => {
2759                let il = RirNode::Join {
2760                    left: Box::new(mk_scan(s.rel_xy)),
2761                    right: Box::new(mk_scan(s.rel_yz)),
2762                    left_keys: vec![1],
2763                    right_keys: vec![0],
2764                    join_type: JoinType::Inner,
2765                };
2766                let ir = RirNode::Join {
2767                    left: Box::new(mk_scan(s.rel_zw)),
2768                    right: Box::new(mk_scan(s.rel_wx)),
2769                    left_keys: vec![1],
2770                    right_keys: vec![0],
2771                    join_type: JoinType::Inner,
2772                };
2773                let outer = RirNode::Join {
2774                    left: Box::new(il),
2775                    right: Box::new(ir),
2776                    left_keys: vec![0, 3],
2777                    right_keys: vec![3, 0],
2778                    join_type: JoinType::Inner,
2779                };
2780                RirNode::Project {
2781                    input: Box::new(outer),
2782                    columns: vec![
2783                        ProjectExpr::Column(5),
2784                        ProjectExpr::Column(0),
2785                        ProjectExpr::Column(1),
2786                        ProjectExpr::Column(3),
2787                    ],
2788                }
2789            }
2790        }
2791    }
2792
2793    pub fn try_reorder_4cycle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2794        let s = match_and_infer_4cycle(body)?;
2795        let _ = (
2796            populated_card(stats, s.rel_wx)?,
2797            populated_card(stats, s.rel_xy)?,
2798            populated_card(stats, s.rel_yz)?,
2799            populated_card(stats, s.rel_zw)?,
2800        );
2801        let est_default = stats
2802            .estimate_join_cardinality(s.rel_wx, s.rel_xy, &[1], &[0])
2803            .saturating_add(stats.estimate_join_cardinality(s.rel_yz, s.rel_zw, &[1], &[0]));
2804        let est_alt = stats
2805            .estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0])
2806            .saturating_add(stats.estimate_join_cardinality(s.rel_zw, s.rel_wx, &[1], &[0]));
2807        let chosen = if est_alt < est_default {
2808            Cycle4Grouping::Alt
2809        } else {
2810            Cycle4Grouping::Default
2811        };
2812        let candidate = build_4cycle_body(&s, chosen);
2813        if format!("{:?}", candidate) == format!("{:?}", body) {
2814            return None;
2815        }
2816        Some(candidate)
2817    }
2818}
2819
2820#[cfg(test)]
2821mod selectivity_pass_tests {
2822    use super::selectivity_pass;
2823    use crate::Compiler;
2824    use xlog_stats::StatsManager;
2825
2826    fn body_snapshots(plan: &xlog_ir::ExecutionPlan) -> Vec<String> {
2827        plan.rules_by_scc
2828            .iter()
2829            .flatten()
2830            .map(|r| format!("{:?}", r.body))
2831            .collect()
2832    }
2833
2834    #[test]
2835    fn selectivity_pass_is_noop_for_triangle_plan() {
2836        let mut compiler = Compiler::new();
2837        let plan = compiler
2838            .compile("tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z).")
2839            .expect("compile");
2840        let before = body_snapshots(&plan);
2841        let stats = StatsManager::new();
2842        let mut plan2 = plan.clone();
2843        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2844        let after = body_snapshots(&plan2);
2845        assert_eq!(
2846            before, after,
2847            "selectivity_pass must preserve every triangle rule body byte-for-byte"
2848        );
2849    }
2850
2851    #[test]
2852    fn selectivity_pass_is_noop_for_4cycle_plan() {
2853        let mut compiler = Compiler::new();
2854        let plan = compiler
2855            .compile("cycle4(W, X, Y, Z) :- e1(W, X), e2(X, Y), e3(Y, Z), e4(Z, W).")
2856            .expect("compile");
2857        let before = body_snapshots(&plan);
2858        let stats = StatsManager::new();
2859        let mut plan2 = plan.clone();
2860        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2861        let after = body_snapshots(&plan2);
2862        assert_eq!(
2863            before, after,
2864            "selectivity_pass must preserve every 4-cycle rule body byte-for-byte"
2865        );
2866    }
2867
2868    #[test]
2869    fn selectivity_pass_is_noop_for_recursive_scc() {
2870        let mut compiler = Compiler::new();
2871        let plan = compiler
2872            .compile(
2873                "edge(1, 2). edge(2, 3). \
2874                 reach(X, Y) :- edge(X, Y). \
2875                 reach(X, Z) :- reach(X, Y), edge(Y, Z).",
2876            )
2877            .expect("compile");
2878        let before = body_snapshots(&plan);
2879        let stats = StatsManager::new();
2880        let mut plan2 = plan.clone();
2881        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2882        let after = body_snapshots(&plan2);
2883        assert_eq!(
2884            before, after,
2885            "selectivity_pass must preserve recursive SCC bodies byte-for-byte"
2886        );
2887    }
2888
2889    // ---------------------------------------------------------
2890    // W2.2 — selectivity-driven reordering tests
2891    // ---------------------------------------------------------
2892
2893    use xlog_core::RelId;
2894    use xlog_ir::plan::{CompiledRule, PlanBuilder, Scc};
2895    use xlog_ir::rir::ProjectExpr;
2896    use xlog_ir::{ExecutionPlan, JoinType, RirNode};
2897
2898    /// Build a hand-crafted canonical lowered triangle plan
2899    /// with three Scans at RelId(1), RelId(2), RelId(3) for
2900    /// (e_xy, e_yz, e_xz). Bypasses the optimizer entirely so
2901    /// the W2.2 cert is a clean stats-→-pair-choice
2902    /// observation, not a confounded test of optimizer + W2.2.
2903    ///
2904    /// Default canonical shape (Y-shared inner): inner keys
2905    /// `[1]/[0]`, outer keys `[0,3]/[0,1]`, project `[0,1,3]`.
2906    fn synth_triangle_plan() -> ExecutionPlan {
2907        let inner = RirNode::Join {
2908            left: Box::new(RirNode::Scan { rel: RelId(1) }),
2909            right: Box::new(RirNode::Scan { rel: RelId(2) }),
2910            left_keys: vec![1],
2911            right_keys: vec![0],
2912            join_type: JoinType::Inner,
2913        };
2914        let outer = RirNode::Join {
2915            left: Box::new(inner),
2916            right: Box::new(RirNode::Scan { rel: RelId(3) }),
2917            left_keys: vec![0, 3],
2918            right_keys: vec![0, 1],
2919            join_type: JoinType::Inner,
2920        };
2921        let body = RirNode::Project {
2922            input: Box::new(outer),
2923            columns: vec![
2924                ProjectExpr::Column(0),
2925                ProjectExpr::Column(1),
2926                ProjectExpr::Column(3),
2927            ],
2928        };
2929        let mut builder = PlanBuilder::new();
2930        builder.add_scc(Scc {
2931            id: 0,
2932            predicates: vec!["tri".to_string()],
2933            is_recursive: false,
2934        });
2935        builder.add_rule(
2936            0,
2937            CompiledRule {
2938                head: "tri".to_string(),
2939                body,
2940                meta: Default::default(),
2941            },
2942        );
2943        builder.build()
2944    }
2945
2946    /// Seed a `StatsManager` with three triangle-edge
2947    /// cardinalities at the conventional RelIds (1, 2, 3) used
2948    /// by `synth_triangle_plan`.
2949    fn seed_triangle_stats(c1: u64, c2: u64, c3: u64) -> StatsManager {
2950        let mut stats = StatsManager::new();
2951        for (rid, card) in [(RelId(1), c1), (RelId(2), c2), (RelId(3), c3)] {
2952            stats.register_relation(rid);
2953            stats.update_cardinality(rid, card);
2954        }
2955        stats
2956    }
2957
2958    /// Inspect the (left RelId, right RelId) of the inner Join
2959    /// in a canonical lowered triangle body. Used by W2.2
2960    /// reordering certs.
2961    ///
2962    /// After `compile()` the body is a `MultiWayJoin` whose
2963    /// `fallback` field holds the post-selectivity-pass
2964    /// pre-promotion shape — that's where the inner-pair
2965    /// signature lives. The helper unwraps `MultiWayJoin →
2966    /// fallback` if needed before drilling into the binary
2967    /// Join structure.
2968    fn inspect_triangle_inner_pair(plan: &xlog_ir::ExecutionPlan) -> Option<(RelId, RelId)> {
2969        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
2970        let body = match body {
2971            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
2972            other => other,
2973        };
2974        let xlog_ir::RirNode::Project { input, .. } = body else {
2975            return None;
2976        };
2977        let xlog_ir::RirNode::Join { left, .. } = input.as_ref() else {
2978            return None;
2979        };
2980        let xlog_ir::RirNode::Join {
2981            left: l2,
2982            right: r2,
2983            ..
2984        } = left.as_ref()
2985        else {
2986            return None;
2987        };
2988        let xlog_ir::RirNode::Scan { rel: rel_l } = l2.as_ref() else {
2989            return None;
2990        };
2991        let xlog_ir::RirNode::Scan { rel: rel_r } = r2.as_ref() else {
2992            return None;
2993        };
2994        Some((*rel_l, *rel_r))
2995    }
2996
2997    /// W2.2 — snapshot 1: cards favor `(e1, e2)` Y-shared inner.
2998    /// Triangle rule: `tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z)`.
2999    /// To make Y-shared smallest, give e1 + e2 small cards and
3000    /// e3 a large card so all pair products are dominated by
3001    /// pairs containing e3 — except the pair (e1, e2) which
3002    /// is the smallest product.
3003    #[test]
3004    fn selectivity_pass_picks_y_shared_inner_when_e1_e2_smallest() {
3005        let mut plan = synth_triangle_plan();
3006        // e1=10, e2=10, e3=100_000 → Y-shared (e1⋈e2) smallest.
3007        let stats = seed_triangle_stats(10, 10, 100_000);
3008        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3009        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3010        // Y-shared inner = (e_xy, e_yz) = (RelId(1), RelId(2)).
3011        assert!(
3012            pair == (RelId(1), RelId(2)) || pair == (RelId(2), RelId(1)),
3013            "expected (RelId(1), RelId(2)) for Y-shared; got {:?}",
3014            pair
3015        );
3016    }
3017
3018    /// W2.2 — snapshot 2: cards favor `(e1, e3)` X-shared inner.
3019    /// e1 + e3 small, e2 large.
3020    #[test]
3021    fn selectivity_pass_picks_x_shared_inner_when_e1_e3_smallest() {
3022        let mut plan = synth_triangle_plan();
3023        // e1=10, e2=100_000, e3=10 → X-shared (e1⋈e3) smallest.
3024        let stats = seed_triangle_stats(10, 100_000, 10);
3025        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3026        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3027        // X-shared inner = (e_xy, e_xz) = (RelId(1), RelId(3)).
3028        assert!(
3029            pair == (RelId(1), RelId(3)) || pair == (RelId(3), RelId(1)),
3030            "expected (RelId(1), RelId(3)) for X-shared; got {:?}",
3031            pair
3032        );
3033    }
3034
3035    /// W2.2 — snapshot 3: cards favor `(e2, e3)` Z-shared inner.
3036    /// e2 + e3 small, e1 large.
3037    #[test]
3038    fn selectivity_pass_picks_z_shared_inner_when_e2_e3_smallest() {
3039        let mut plan = synth_triangle_plan();
3040        // e1=100_000, e2=10, e3=10 → Z-shared (e2⋈e3) smallest.
3041        let stats = seed_triangle_stats(100_000, 10, 10);
3042        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3043        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3044        // Z-shared inner = (e_yz, e_xz) = (RelId(2), RelId(3)).
3045        assert!(
3046            pair == (RelId(2), RelId(3)) || pair == (RelId(3), RelId(2)),
3047            "expected (RelId(2), RelId(3)) for Z-shared; got {:?}",
3048            pair
3049        );
3050    }
3051
3052    /// W2.2 — two snapshots produce different inner pairs. Pins
3053    /// "stats drive the order, not deterministic
3054    /// canonicalization." Deterministic canonicalization that
3055    /// ignores stats CANNOT pass this gate.
3056    #[test]
3057    fn selectivity_pass_two_snapshots_produce_different_inner_pairs() {
3058        let mut plan_a = synth_triangle_plan();
3059        let stats_a = seed_triangle_stats(10, 10, 100_000); // Y-shared
3060        selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3061        let pair_a = inspect_triangle_inner_pair(&plan_a).expect("snapshot A pair");
3062
3063        let mut plan_b = synth_triangle_plan();
3064        let stats_b = seed_triangle_stats(100_000, 10, 10); // Z-shared
3065        selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3066        let pair_b = inspect_triangle_inner_pair(&plan_b).expect("snapshot B pair");
3067
3068        let normalize = |(a, b): (RelId, RelId)| -> (RelId, RelId) {
3069            if a.0 <= b.0 {
3070                (a, b)
3071            } else {
3072                (b, a)
3073            }
3074        };
3075        assert_ne!(
3076            normalize(pair_a),
3077            normalize(pair_b),
3078            "two different stats snapshots must produce different inner pairs; \
3079             got A = {:?}, B = {:?}",
3080            pair_a,
3081            pair_b
3082        );
3083    }
3084
3085    /// W2.2 — fallback edge case: relation cards present but no
3086    /// column statistics. The 10% default fallback inside
3087    /// `estimate_join_cardinality` means all three pair
3088    /// estimates collapse to roughly the same ratio. The pass
3089    /// either picks SOME pair or leaves the body unchanged;
3090    /// the test is tolerant by design and documents the
3091    /// uninformative-fallback case explicitly.
3092    #[test]
3093    fn selectivity_pass_with_only_relation_cards_may_pick_arbitrary_pair() {
3094        let mut plan = synth_triangle_plan();
3095        // All three cards equal — no column stats to break ties.
3096        let stats = seed_triangle_stats(100, 100, 100);
3097        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3098        // Either a triangle inner pair is identifiable (any of
3099        // the three) or the body stays unchanged. Both are OK.
3100        let _ = inspect_triangle_inner_pair(&plan);
3101    }
3102
3103    // ---------------------------------------------------------
3104    // W2.2 — 4-cycle compile-time reordering tests
3105    // ---------------------------------------------------------
3106
3107    /// Build a hand-crafted canonical lowered 4-cycle plan
3108    /// with four Scans at RelId(1), RelId(2), RelId(3), RelId(4)
3109    /// for (e_wx, e_xy, e_yz, e_zw). Bypasses the optimizer.
3110    /// Default canonical bushy shape: inner-left
3111    /// `(e_wx ⋈ e_xy)` on X, inner-right `(e_yz ⋈ e_zw)` on Z,
3112    /// outer keys `[0, 3] / [3, 0]`, project `[0, 1, 3, 5]`.
3113    fn synth_4cycle_plan() -> ExecutionPlan {
3114        let inner_left = RirNode::Join {
3115            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3116            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3117            left_keys: vec![1],
3118            right_keys: vec![0],
3119            join_type: JoinType::Inner,
3120        };
3121        let inner_right = RirNode::Join {
3122            left: Box::new(RirNode::Scan { rel: RelId(3) }),
3123            right: Box::new(RirNode::Scan { rel: RelId(4) }),
3124            left_keys: vec![1],
3125            right_keys: vec![0],
3126            join_type: JoinType::Inner,
3127        };
3128        let outer = RirNode::Join {
3129            left: Box::new(inner_left),
3130            right: Box::new(inner_right),
3131            left_keys: vec![0, 3],
3132            right_keys: vec![3, 0],
3133            join_type: JoinType::Inner,
3134        };
3135        let body = RirNode::Project {
3136            input: Box::new(outer),
3137            columns: vec![
3138                ProjectExpr::Column(0),
3139                ProjectExpr::Column(1),
3140                ProjectExpr::Column(3),
3141                ProjectExpr::Column(5),
3142            ],
3143        };
3144        let mut builder = PlanBuilder::new();
3145        builder.add_scc(Scc {
3146            id: 0,
3147            predicates: vec!["cyc".to_string()],
3148            is_recursive: false,
3149        });
3150        builder.add_rule(
3151            0,
3152            CompiledRule {
3153                head: "cyc".to_string(),
3154                body,
3155                meta: Default::default(),
3156            },
3157        );
3158        builder.build()
3159    }
3160
3161    fn seed_4cycle_stats(c1: u64, c2: u64, c3: u64, c4: u64) -> StatsManager {
3162        let mut stats = StatsManager::new();
3163        for (rid, card) in [
3164            (RelId(1), c1),
3165            (RelId(2), c2),
3166            (RelId(3), c3),
3167            (RelId(4), c4),
3168        ] {
3169            stats.register_relation(rid);
3170            stats.update_cardinality(rid, card);
3171        }
3172        stats
3173    }
3174
3175    /// Recover the 4-cycle inner-grouping signature: `(left_left,
3176    /// left_right, right_left, right_right)` Scan RelIds. Used
3177    /// to identify which grouping the rewriter chose.
3178    fn inspect_4cycle_grouping(
3179        plan: &xlog_ir::ExecutionPlan,
3180    ) -> Option<(RelId, RelId, RelId, RelId)> {
3181        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
3182        let body = match body {
3183            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
3184            other => other,
3185        };
3186        let xlog_ir::RirNode::Project { input, .. } = body else {
3187            return None;
3188        };
3189        let xlog_ir::RirNode::Join { left, right, .. } = input.as_ref() else {
3190            return None;
3191        };
3192        let xlog_ir::RirNode::Join {
3193            left: ll,
3194            right: lr,
3195            ..
3196        } = left.as_ref()
3197        else {
3198            return None;
3199        };
3200        let xlog_ir::RirNode::Join {
3201            left: rl,
3202            right: rr,
3203            ..
3204        } = right.as_ref()
3205        else {
3206            return None;
3207        };
3208        let xlog_ir::RirNode::Scan { rel: r_ll } = ll.as_ref() else {
3209            return None;
3210        };
3211        let xlog_ir::RirNode::Scan { rel: r_lr } = lr.as_ref() else {
3212            return None;
3213        };
3214        let xlog_ir::RirNode::Scan { rel: r_rl } = rl.as_ref() else {
3215            return None;
3216        };
3217        let xlog_ir::RirNode::Scan { rel: r_rr } = rr.as_ref() else {
3218            return None;
3219        };
3220        Some((*r_ll, *r_lr, *r_rl, *r_rr))
3221    }
3222
3223    /// W2.2 — 4-cycle: cards favor Default grouping
3224    /// `(e_wx⋈e_xy on X) + (e_yz⋈e_zw on Z)`. Default cost is
3225    /// `est(WX⋈XY)+est(YZ⋈ZW) = 0.1*c1*c2 + 0.1*c3*c4`.
3226    /// Alt cost is `0.1*c2*c3 + 0.1*c4*c1`. Default smaller
3227    /// when `c1*c2 + c3*c4 < c2*c3 + c4*c1`. With
3228    /// (c1=10, c2=10, c3=100_000, c4=100_000):
3229    ///   default = 100 + 10^10 ≈ 10^10.
3230    ///   alt = 10^6 + 10^6 ≈ 2*10^6.
3231    /// → alt is smaller, so this fixture actually favors Alt.
3232    /// Use (c1=10, c2=10, c3=10, c4=10_000_000) instead:
3233    ///   default = 100 + 10^8 = 10^8.
3234    ///   alt = 100 + 10^8 = 10^8 (same).
3235    /// Need uneven c4 vs others: (c1=10, c2=10, c3=10_000_000, c4=10):
3236    ///   default = 100 + 10^8 = 10^8.
3237    ///   alt = 10^8 + 100 = 10^8 (same).
3238    /// Default favored when c1*c2 << c2*c3 AND c3*c4 << c4*c1.
3239    /// I.e., c1 small and c4 small relative to c2 and c3.
3240    /// (c1=10, c2=10_000, c3=10_000, c4=10):
3241    ///   default = 0.1*100_000 + 0.1*100_000 = 20_000.
3242    ///   alt = 0.1*100_000_000 + 0.1*100 = 10_000_010.
3243    /// → Default smaller. ✓
3244    #[test]
3245    fn selectivity_pass_4cycle_picks_default_grouping_when_corners_smallest() {
3246        let mut plan = synth_4cycle_plan();
3247        let stats = seed_4cycle_stats(10, 10_000, 10_000, 10);
3248        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3249        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3250        // Default: (e_wx, e_xy, e_yz, e_zw) = (RelId(1..4)).
3251        assert_eq!(
3252            (ll, lr, rl, rr),
3253            (RelId(1), RelId(2), RelId(3), RelId(4)),
3254            "expected Default grouping"
3255        );
3256    }
3257
3258    /// W2.2 — 4-cycle: cards favor Alt grouping
3259    /// `(e_xy⋈e_yz on Y) + (e_zw⋈e_wx on W)`. Alt smaller when
3260    /// `c2*c3 + c4*c1 < c1*c2 + c3*c4`. Use
3261    /// (c1=10_000, c2=10, c3=10, c4=10_000):
3262    ///   default = 0.1*100_000 + 0.1*100_000 = 20_000.
3263    ///   alt = 0.1*100 + 0.1*10^8 = 10_000_010.
3264    /// → Default still wins. Need c1*c2 LARGE and c3*c4 LARGE
3265    /// while c2*c3 SMALL and c4*c1 SMALL. Try
3266    /// (c1=10_000, c2=10_000, c3=10, c4=10):
3267    ///   default = 0.1*10^8 + 0.1*100 = 10_000_010.
3268    ///   alt = 0.1*100_000 + 0.1*100_000 = 20_000.
3269    /// → Alt smaller. ✓
3270    #[test]
3271    fn selectivity_pass_4cycle_picks_alt_grouping_when_diagonals_smallest() {
3272        let mut plan = synth_4cycle_plan();
3273        let stats = seed_4cycle_stats(10_000, 10_000, 10, 10);
3274        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3275        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3276        // Alt: (e_xy, e_yz, e_zw, e_wx) = (RelId(2), RelId(3), RelId(4), RelId(1)).
3277        assert_eq!(
3278            (ll, lr, rl, rr),
3279            (RelId(2), RelId(3), RelId(4), RelId(1)),
3280            "expected Alt grouping"
3281        );
3282    }
3283
3284    /// W2.2 — same plan, two stats snapshots → two different
3285    /// 4-cycle groupings. Pins "stats drive the choice" for
3286    /// 4-cycle.
3287    #[test]
3288    fn selectivity_pass_4cycle_two_snapshots_produce_different_groupings() {
3289        let mut plan_a = synth_4cycle_plan();
3290        let stats_a = seed_4cycle_stats(10, 10_000, 10_000, 10); // Default.
3291        selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3292        let g_a = inspect_4cycle_grouping(&plan_a).expect("grouping a");
3293
3294        let mut plan_b = synth_4cycle_plan();
3295        let stats_b = seed_4cycle_stats(10_000, 10_000, 10, 10); // Alt.
3296        selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3297        let g_b = inspect_4cycle_grouping(&plan_b).expect("grouping b");
3298
3299        assert_ne!(
3300            g_a, g_b,
3301            "two different stats snapshots must produce different 4-cycle groupings; \
3302             got A = {:?}, B = {:?}",
3303            g_a, g_b
3304        );
3305    }
3306
3307    /// W2.2 — 4-cycle missing-stats safety floor: any unseeded
3308    /// relation → body unchanged.
3309    #[test]
3310    fn selectivity_pass_4cycle_skips_when_card_missing() {
3311        let mut plan = synth_4cycle_plan();
3312        // Only seed 3 of 4.
3313        let mut stats = StatsManager::new();
3314        for rid in [RelId(1), RelId(2), RelId(3)] {
3315            stats.register_relation(rid);
3316            stats.update_cardinality(rid, 100);
3317        }
3318        let before = format!("{:?}", plan.rules_by_scc[0][0].body);
3319        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3320        let after = format!("{:?}", plan.rules_by_scc[0][0].body);
3321        assert_eq!(
3322            before, after,
3323            "missing-stats safety floor must leave body unchanged"
3324        );
3325    }
3326}
3327
3328#[cfg(test)]
3329mod tests {
3330    use super::*;
3331    use xlog_core::ScalarType;
3332    use xlog_ir::{ConstValue, ProjectExpr};
3333    use xlog_stats::ColumnStats;
3334
3335    fn make_stats_manager() -> Arc<StatsManager> {
3336        let mut mgr = StatsManager::new();
3337
3338        // Register test relations with realistic statistics
3339        mgr.register_relation(RelId(1));
3340        mgr.update_cardinality(RelId(1), 10_000);
3341        mgr.update_byte_size(RelId(1), 320_000); // ~32 bytes per row
3342
3343        mgr.register_relation(RelId(2));
3344        mgr.update_cardinality(RelId(2), 5_000);
3345        mgr.update_byte_size(RelId(2), 160_000);
3346
3347        mgr.register_relation(RelId(3));
3348        mgr.update_cardinality(RelId(3), 1_000);
3349        mgr.update_byte_size(RelId(3), 32_000);
3350
3351        // Add column statistics for relation 1
3352        let mut col0 = ColumnStats::new(0, ScalarType::I64);
3353        col0.update_distinct(1000);
3354        col0.update_range(0, 10000);
3355        mgr.add_column_stats(RelId(1), col0);
3356
3357        let mut col1 = ColumnStats::new(1, ScalarType::I64);
3358        col1.update_distinct(100);
3359        mgr.add_column_stats(RelId(1), col1);
3360
3361        Arc::new(mgr)
3362    }
3363
3364    #[test]
3365    fn test_optimizer_new() {
3366        let stats = make_stats_manager();
3367        let optimizer = Optimizer::new(stats);
3368
3369        assert_eq!(optimizer.config().dp_threshold, 10);
3370        assert!(optimizer.config().enable_pushdown);
3371    }
3372
3373    #[test]
3374    fn test_optimizer_with_config() {
3375        let stats = make_stats_manager();
3376        let config = OptimizerConfig {
3377            dp_threshold: 5,
3378            enable_pushdown: false,
3379            ..Default::default()
3380        };
3381        let optimizer = Optimizer::with_config(stats, config);
3382
3383        assert_eq!(optimizer.config().dp_threshold, 5);
3384        assert!(!optimizer.config().enable_pushdown);
3385    }
3386
3387    #[test]
3388    fn test_estimate_scan_cost() {
3389        let stats = make_stats_manager();
3390        let optimizer = Optimizer::new(stats);
3391
3392        let scan = RirNode::Scan { rel: RelId(1) };
3393        let cost = optimizer.estimate_cost(&scan);
3394
3395        assert_eq!(cost.rows, 10_000);
3396        assert!(cost.gpu_mem > 0);
3397        assert_eq!(cost.transfers, 0); // Data on GPU
3398    }
3399
3400    #[test]
3401    fn test_estimate_scan_cost_unknown_relation() {
3402        let stats = Arc::new(StatsManager::new());
3403        let optimizer = Optimizer::new(stats);
3404
3405        let scan = RirNode::Scan { rel: RelId(999) };
3406        let cost = optimizer.estimate_cost(&scan);
3407
3408        // Should use defaults
3409        assert_eq!(cost.rows, 1000);
3410    }
3411
3412    #[test]
3413    fn test_estimate_filter_cost() {
3414        let stats = make_stats_manager();
3415        let optimizer = Optimizer::new(stats);
3416
3417        let filter = RirNode::Filter {
3418            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3419            predicate: Expr::Compare {
3420                left: Box::new(Expr::Column(0)),
3421                op: CompareOp::Eq,
3422                right: Box::new(Expr::Const(ConstValue::I64(42))),
3423            },
3424        };
3425
3426        let cost = optimizer.estimate_cost(&filter);
3427
3428        // Filter should reduce row count
3429        assert!(cost.rows < 10_000);
3430        assert!(cost.rows >= 1);
3431    }
3432
3433    #[test]
3434    fn test_estimate_join_cost() {
3435        let stats = make_stats_manager();
3436        let optimizer = Optimizer::new(stats);
3437
3438        let join = RirNode::Join {
3439            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3440            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3441            left_keys: vec![0],
3442            right_keys: vec![0],
3443            join_type: JoinType::Inner,
3444        };
3445
3446        let cost = optimizer.estimate_cost(&join);
3447
3448        // Should have positive estimates
3449        assert!(cost.rows > 0);
3450        assert!(cost.cpu_cost > 0.0);
3451        assert!(cost.gpu_mem > 0);
3452    }
3453
3454    #[test]
3455    fn test_estimate_join_cost_with_selectivity() {
3456        let mut mgr = StatsManager::new();
3457        mgr.register_relation(RelId(1));
3458        mgr.register_relation(RelId(2));
3459        mgr.update_cardinality(RelId(1), 1000);
3460        mgr.update_cardinality(RelId(2), 500);
3461
3462        // Record a join result to cache selectivity
3463        mgr.record_join_result(RelId(1), RelId(2), vec![0], vec![0], 500_000, 2500);
3464
3465        let optimizer = Optimizer::new(Arc::new(mgr));
3466
3467        let join = RirNode::Join {
3468            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3469            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3470            left_keys: vec![0],
3471            right_keys: vec![0],
3472            join_type: JoinType::Inner,
3473        };
3474
3475        let cost = optimizer.estimate_cost(&join);
3476
3477        // Should use cached selectivity for estimate
3478        assert!(cost.rows > 0);
3479    }
3480
3481    #[test]
3482    fn test_predicate_pushdown_simple_scan() {
3483        let stats = make_stats_manager();
3484        let optimizer = Optimizer::new(stats);
3485
3486        let scan = RirNode::Scan { rel: RelId(1) };
3487        let optimized = optimizer.optimize(scan);
3488
3489        // Scan should pass through unchanged
3490        assert!(matches!(optimized, RirNode::Scan { rel: RelId(1) }));
3491    }
3492
3493    #[test]
3494    fn test_predicate_pushdown_filter_on_scan() {
3495        let stats = make_stats_manager();
3496        let optimizer = Optimizer::new(stats);
3497
3498        let filter = RirNode::Filter {
3499            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3500            predicate: Expr::Compare {
3501                left: Box::new(Expr::Column(0)),
3502                op: CompareOp::Eq,
3503                right: Box::new(Expr::Const(ConstValue::I64(42))),
3504            },
3505        };
3506
3507        let optimized = optimizer.optimize(filter);
3508
3509        // Filter on scan should stay in place
3510        assert!(matches!(optimized, RirNode::Filter { .. }));
3511    }
3512
3513    #[test]
3514    fn test_predicate_pushdown_merges_filters() {
3515        let stats = make_stats_manager();
3516        let optimizer = Optimizer::new(stats);
3517
3518        let nested_filter = RirNode::Filter {
3519            input: Box::new(RirNode::Filter {
3520                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3521                predicate: Expr::Compare {
3522                    left: Box::new(Expr::Column(0)),
3523                    op: CompareOp::Gt,
3524                    right: Box::new(Expr::Const(ConstValue::I64(0))),
3525                },
3526            }),
3527            predicate: Expr::Compare {
3528                left: Box::new(Expr::Column(0)),
3529                op: CompareOp::Lt,
3530                right: Box::new(Expr::Const(ConstValue::I64(100))),
3531            },
3532        };
3533
3534        let optimized = optimizer.optimize(nested_filter);
3535
3536        // Filters should be merged into AND
3537        if let RirNode::Filter { predicate, .. } = optimized {
3538            assert!(matches!(predicate, Expr::And(_)));
3539        } else {
3540            panic!("Expected Filter node");
3541        }
3542    }
3543
3544    #[test]
3545    fn test_predicate_pushdown_through_project() {
3546        let stats = make_stats_manager();
3547        let optimizer = Optimizer::new(stats);
3548
3549        // Filter on projected column that's a pass-through
3550        let plan = RirNode::Filter {
3551            input: Box::new(RirNode::Project {
3552                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3553                columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(1)],
3554            }),
3555            predicate: Expr::Compare {
3556                left: Box::new(Expr::Column(0)),
3557                op: CompareOp::Eq,
3558                right: Box::new(Expr::Const(ConstValue::I64(42))),
3559            },
3560        };
3561
3562        let optimized = optimizer.optimize(plan);
3563
3564        // Filter should be pushed below project
3565        assert!(matches!(optimized, RirNode::Project { .. }));
3566        if let RirNode::Project { input, .. } = optimized {
3567            assert!(matches!(*input, RirNode::Filter { .. }));
3568        }
3569    }
3570
3571    #[test]
3572    fn test_predicate_pushdown_into_join() {
3573        let stats = make_stats_manager();
3574        let optimizer = Optimizer::new(stats);
3575
3576        // Filter on left side column only
3577        let plan = RirNode::Filter {
3578            input: Box::new(RirNode::Join {
3579                left: Box::new(RirNode::Scan { rel: RelId(1) }),
3580                right: Box::new(RirNode::Scan { rel: RelId(2) }),
3581                left_keys: vec![0],
3582                right_keys: vec![0],
3583                join_type: JoinType::Inner,
3584            }),
3585            predicate: Expr::Compare {
3586                left: Box::new(Expr::Column(0)), // Left side column
3587                op: CompareOp::Eq,
3588                right: Box::new(Expr::Const(ConstValue::I64(42))),
3589            },
3590        };
3591
3592        let optimized = optimizer.optimize(plan);
3593
3594        // Filter should be pushed into left side of join
3595        if let RirNode::Join { left, .. } = optimized {
3596            assert!(matches!(*left, RirNode::Filter { .. }));
3597        } else {
3598            panic!("Expected Join node");
3599        }
3600    }
3601
3602    #[test]
3603    fn test_plan_cost_total() {
3604        let cost = PlanCost {
3605            rows: 1000,
3606            cpu_cost: 100.0,
3607            gpu_mem: 1_000_000,
3608            transfers: 2,
3609        };
3610
3611        let total = cost.total_cost(100.0);
3612
3613        // cpu_cost + gpu_mem*0.001 + transfers*100
3614        // 100.0 + 1000.0 + 200.0 = 1300.0
3615        assert!((total - 1300.0).abs() < 0.001);
3616    }
3617
3618    #[test]
3619    fn test_plan_cost_then() {
3620        let cost1 = PlanCost {
3621            rows: 1000,
3622            cpu_cost: 50.0,
3623            gpu_mem: 500,
3624            transfers: 1,
3625        };
3626
3627        let cost2 = PlanCost {
3628            rows: 500,
3629            cpu_cost: 25.0,
3630            gpu_mem: 800,
3631            transfers: 1,
3632        };
3633
3634        let combined = cost1.then(cost2);
3635
3636        assert_eq!(combined.rows, 500); // Takes output rows from second
3637        assert_eq!(combined.cpu_cost, 75.0);
3638        assert_eq!(combined.gpu_mem, 800); // Peak memory
3639        assert_eq!(combined.transfers, 2);
3640    }
3641
3642    #[test]
3643    fn test_optimizer_config_default() {
3644        let config = OptimizerConfig::default();
3645
3646        assert_eq!(config.dp_threshold, 10);
3647        assert!((config.index_heat_threshold - 0.7).abs() < 0.001);
3648        assert!(config.enable_pushdown);
3649        assert!((config.default_filter_selectivity - 0.1).abs() < 0.001);
3650    }
3651
3652    #[test]
3653    fn test_should_use_greedy() {
3654        let stats = make_stats_manager();
3655        let config = OptimizerConfig {
3656            dp_threshold: 2,
3657            ..Default::default()
3658        };
3659        let optimizer = Optimizer::with_config(stats, config);
3660
3661        // Single relation: should NOT use greedy
3662        let single = RirNode::Scan { rel: RelId(1) };
3663        assert!(!optimizer.should_use_greedy(&single));
3664
3665        // Three relations: should use greedy (threshold is 2)
3666        let multi = RirNode::Join {
3667            left: Box::new(RirNode::Join {
3668                left: Box::new(RirNode::Scan { rel: RelId(1) }),
3669                right: Box::new(RirNode::Scan { rel: RelId(2) }),
3670                left_keys: vec![0],
3671                right_keys: vec![0],
3672                join_type: JoinType::Inner,
3673            }),
3674            right: Box::new(RirNode::Scan { rel: RelId(3) }),
3675            left_keys: vec![0],
3676            right_keys: vec![0],
3677            join_type: JoinType::Inner,
3678        };
3679        assert!(optimizer.should_use_greedy(&multi));
3680    }
3681
3682    #[test]
3683    fn test_recommend_indexes() {
3684        let mut mgr = StatsManager::new();
3685        mgr.register_relation(RelId(1));
3686        mgr.register_relation(RelId(2));
3687
3688        // Heat up relation 1 extensively
3689        for _ in 0..50 {
3690            mgr.record_access(RelId(1));
3691        }
3692
3693        let optimizer = Optimizer::new(Arc::new(mgr));
3694        let recommendations = optimizer.recommend_indexes();
3695
3696        assert!(recommendations.contains(&RelId(1)));
3697        assert!(!recommendations.contains(&RelId(2)));
3698    }
3699
3700    #[test]
3701    fn test_estimate_groupby_cost() {
3702        let stats = make_stats_manager();
3703        let optimizer = Optimizer::new(stats);
3704
3705        let groupby = RirNode::GroupBy {
3706            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3707            key_cols: vec![0],
3708            aggs: vec![(1, xlog_core::AggOp::Sum)],
3709        };
3710
3711        let cost = optimizer.estimate_cost(&groupby);
3712
3713        // GroupBy should reduce row count
3714        assert!(cost.rows < 10_000);
3715        assert!(cost.rows >= 1);
3716    }
3717
3718    #[test]
3719    fn test_estimate_union_cost() {
3720        let stats = make_stats_manager();
3721        let optimizer = Optimizer::new(stats);
3722
3723        let union = RirNode::Union {
3724            inputs: vec![
3725                RirNode::Scan { rel: RelId(1) },
3726                RirNode::Scan { rel: RelId(2) },
3727            ],
3728        };
3729
3730        let cost = optimizer.estimate_cost(&union);
3731
3732        // Union sums row counts
3733        assert_eq!(cost.rows, 15_000); // 10000 + 5000
3734    }
3735
3736    #[test]
3737    fn test_estimate_distinct_cost() {
3738        let stats = make_stats_manager();
3739        let optimizer = Optimizer::new(stats);
3740
3741        let distinct = RirNode::Distinct {
3742            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3743            key_cols: vec![0],
3744        };
3745
3746        let cost = optimizer.estimate_cost(&distinct);
3747
3748        // Distinct reduces rows
3749        assert!(cost.rows <= 10_000);
3750        assert!(cost.rows >= 1);
3751    }
3752
3753    #[test]
3754    fn test_estimate_diff_cost() {
3755        let stats = make_stats_manager();
3756        let optimizer = Optimizer::new(stats);
3757
3758        let diff = RirNode::Diff {
3759            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3760            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3761        };
3762
3763        let cost = optimizer.estimate_cost(&diff);
3764
3765        // Diff reduces left side
3766        assert!(cost.rows <= 10_000);
3767        assert!(cost.rows >= 1);
3768    }
3769
3770    #[test]
3771    fn test_estimate_fixpoint_cost() {
3772        let stats = make_stats_manager();
3773        let optimizer = Optimizer::new(stats);
3774
3775        let fixpoint = RirNode::Fixpoint {
3776            scc_id: 0,
3777            base: Box::new(RirNode::Scan { rel: RelId(1) }),
3778            recursive: Box::new(RirNode::Scan { rel: RelId(1) }),
3779            delta_rel: RelId(10),
3780            full_rel: RelId(11),
3781        };
3782
3783        let cost = optimizer.estimate_cost(&fixpoint);
3784
3785        // Fixpoint accumulates rows across iterations
3786        assert!(cost.rows >= 10_000);
3787    }
3788
3789    #[test]
3790    fn test_predicate_selectivity_equality() {
3791        let stats = make_stats_manager();
3792        let optimizer = Optimizer::new(stats);
3793
3794        let scan = RirNode::Scan { rel: RelId(1) };
3795
3796        // Equality predicate
3797        let eq_pred = Expr::Compare {
3798            left: Box::new(Expr::Column(0)),
3799            op: CompareOp::Eq,
3800            right: Box::new(Expr::Const(ConstValue::I64(42))),
3801        };
3802
3803        let selectivity = optimizer.estimate_predicate_selectivity(&eq_pred, &scan);
3804
3805        // With 1000 distinct values, selectivity should be ~0.001
3806        assert!(selectivity < 0.01);
3807        assert!(selectivity > 0.0);
3808    }
3809
3810    #[test]
3811    fn test_predicate_selectivity_and() {
3812        let stats = make_stats_manager();
3813        let optimizer = Optimizer::new(stats);
3814
3815        let scan = RirNode::Scan { rel: RelId(1) };
3816
3817        // AND of two predicates
3818        let and_pred = Expr::And(vec![
3819            Expr::Compare {
3820                left: Box::new(Expr::Column(0)),
3821                op: CompareOp::Gt,
3822                right: Box::new(Expr::Const(ConstValue::I64(0))),
3823            },
3824            Expr::Compare {
3825                left: Box::new(Expr::Column(0)),
3826                op: CompareOp::Lt,
3827                right: Box::new(Expr::Const(ConstValue::I64(100))),
3828            },
3829        ]);
3830
3831        let selectivity = optimizer.estimate_predicate_selectivity(&and_pred, &scan);
3832
3833        // Product of individual selectivities (0.33 * 0.33 ≈ 0.11)
3834        assert!(selectivity < 0.5);
3835        assert!(selectivity > 0.0);
3836    }
3837
3838    #[test]
3839    fn test_predicate_selectivity_not() {
3840        let stats = make_stats_manager();
3841        let optimizer = Optimizer::new(stats);
3842
3843        let scan = RirNode::Scan { rel: RelId(1) };
3844
3845        // NOT of equality
3846        let not_pred = Expr::Not(Box::new(Expr::Compare {
3847            left: Box::new(Expr::Column(0)),
3848            op: CompareOp::Eq,
3849            right: Box::new(Expr::Const(ConstValue::I64(42))),
3850        }));
3851
3852        let selectivity = optimizer.estimate_predicate_selectivity(&not_pred, &scan);
3853
3854        // NOT(equality) should have high selectivity
3855        assert!(selectivity > 0.9);
3856    }
3857
3858    #[test]
3859    fn test_join_type_semi() {
3860        let stats = make_stats_manager();
3861        let optimizer = Optimizer::new(stats);
3862
3863        let semi_join = RirNode::Join {
3864            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3865            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3866            left_keys: vec![0],
3867            right_keys: vec![0],
3868            join_type: JoinType::Semi,
3869        };
3870
3871        let cost = optimizer.estimate_cost(&semi_join);
3872
3873        // Semi join outputs at most left side rows
3874        assert!(cost.rows <= 10_000);
3875    }
3876
3877    #[test]
3878    fn test_join_type_anti() {
3879        let stats = make_stats_manager();
3880        let optimizer = Optimizer::new(stats);
3881
3882        let anti_join = RirNode::Join {
3883            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3884            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3885            left_keys: vec![0],
3886            right_keys: vec![0],
3887            join_type: JoinType::Anti,
3888        };
3889
3890        let cost = optimizer.estimate_cost(&anti_join);
3891
3892        // Anti join outputs at most left side rows
3893        assert!(cost.rows <= 10_000);
3894    }
3895
3896    #[test]
3897    fn test_pushdown_disabled() {
3898        let stats = make_stats_manager();
3899        let config = OptimizerConfig {
3900            enable_pushdown: false,
3901            ..Default::default()
3902        };
3903        let optimizer = Optimizer::with_config(stats, config);
3904
3905        // Filter that could be pushed
3906        let plan = RirNode::Filter {
3907            input: Box::new(RirNode::Filter {
3908                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3909                predicate: Expr::Compare {
3910                    left: Box::new(Expr::Column(0)),
3911                    op: CompareOp::Gt,
3912                    right: Box::new(Expr::Const(ConstValue::I64(0))),
3913                },
3914            }),
3915            predicate: Expr::Compare {
3916                left: Box::new(Expr::Column(0)),
3917                op: CompareOp::Lt,
3918                right: Box::new(Expr::Const(ConstValue::I64(100))),
3919            },
3920        };
3921
3922        let optimized = optimizer.optimize(plan.clone());
3923
3924        // With pushdown disabled, structure should remain the same
3925        // (outer filter, inner filter, scan)
3926        if let RirNode::Filter { input, .. } = optimized {
3927            assert!(matches!(*input, RirNode::Filter { .. }));
3928        } else {
3929            panic!("Expected Filter node");
3930        }
3931    }
3932
3933    #[test]
3934    fn test_collect_columns() {
3935        let expr = Expr::And(vec![
3936            Expr::Compare {
3937                left: Box::new(Expr::Column(0)),
3938                op: CompareOp::Eq,
3939                right: Box::new(Expr::Column(2)),
3940            },
3941            Expr::Compare {
3942                left: Box::new(Expr::Column(1)),
3943                op: CompareOp::Gt,
3944                right: Box::new(Expr::Const(ConstValue::I64(0))),
3945            },
3946        ]);
3947
3948        let cols = Optimizer::collect_columns(&expr);
3949
3950        assert!(cols.contains(&0));
3951        assert!(cols.contains(&1));
3952        assert!(cols.contains(&2));
3953    }
3954
3955    #[test]
3956    fn test_flatten_and() {
3957        let nested = Expr::And(vec![
3958            Expr::And(vec![
3959                Expr::Compare {
3960                    left: Box::new(Expr::Column(0)),
3961                    op: CompareOp::Eq,
3962                    right: Box::new(Expr::Const(ConstValue::I64(1))),
3963                },
3964                Expr::Compare {
3965                    left: Box::new(Expr::Column(1)),
3966                    op: CompareOp::Eq,
3967                    right: Box::new(Expr::Const(ConstValue::I64(2))),
3968                },
3969            ]),
3970            Expr::Compare {
3971                left: Box::new(Expr::Column(2)),
3972                op: CompareOp::Eq,
3973                right: Box::new(Expr::Const(ConstValue::I64(3))),
3974            },
3975        ]);
3976
3977        let flattened = Optimizer::flatten_and(&nested);
3978
3979        assert_eq!(flattened.len(), 3);
3980    }
3981
3982    #[test]
3983    fn test_conjoin_single() {
3984        let single = vec![Expr::Compare {
3985            left: Box::new(Expr::Column(0)),
3986            op: CompareOp::Eq,
3987            right: Box::new(Expr::Const(ConstValue::I64(42))),
3988        }];
3989
3990        let result = Optimizer::conjoin(single);
3991
3992        assert!(matches!(result, Expr::Compare { .. }));
3993    }
3994
3995    #[test]
3996    fn test_conjoin_multiple() {
3997        let multiple = vec![
3998            Expr::Compare {
3999                left: Box::new(Expr::Column(0)),
4000                op: CompareOp::Eq,
4001                right: Box::new(Expr::Const(ConstValue::I64(1))),
4002            },
4003            Expr::Compare {
4004                left: Box::new(Expr::Column(1)),
4005                op: CompareOp::Eq,
4006                right: Box::new(Expr::Const(ConstValue::I64(2))),
4007            },
4008        ];
4009
4010        let result = Optimizer::conjoin(multiple);
4011
4012        assert!(matches!(result, Expr::And(_)));
4013    }
4014
4015    #[test]
4016    fn test_predicate_pushdown_with_schemas() {
4017        // Regression test: ensure predicate pushdown uses schemas for accurate width estimation.
4018        // Without schemas, the optimizer could incorrectly remap column indices.
4019        let stats = make_stats_manager();
4020        let mut optimizer = Optimizer::new(stats);
4021
4022        // Set up schemas: left has 3 columns, right has 3 columns
4023        let left_schema = Schema::new(vec![
4024            ("c0".to_string(), xlog_core::ScalarType::Symbol),
4025            ("c1".to_string(), xlog_core::ScalarType::Symbol),
4026            ("c2".to_string(), xlog_core::ScalarType::Symbol),
4027        ]);
4028        let right_schema = Schema::new(vec![
4029            ("c0".to_string(), xlog_core::ScalarType::Symbol),
4030            ("c1".to_string(), xlog_core::ScalarType::Symbol),
4031            ("c2".to_string(), xlog_core::ScalarType::U32),
4032        ]);
4033
4034        let mut schemas = HashMap::new();
4035        schemas.insert(RelId(1), left_schema);
4036        schemas.insert(RelId(2), right_schema);
4037        optimizer.set_schemas(schemas);
4038
4039        // Filter on Column(5) which is in the right side (left_width=3, so column 5-3=2 in right)
4040        let plan = RirNode::Filter {
4041            input: Box::new(RirNode::Join {
4042                left: Box::new(RirNode::Scan { rel: RelId(1) }),
4043                right: Box::new(RirNode::Scan { rel: RelId(2) }),
4044                left_keys: vec![0],
4045                right_keys: vec![0],
4046                join_type: JoinType::Inner,
4047            }),
4048            predicate: Expr::Compare {
4049                left: Box::new(Expr::Column(5)), // Right side column (index 5 = 3 + 2)
4050                op: CompareOp::Ge,
4051                right: Box::new(Expr::Const(ConstValue::U32(4))),
4052            },
4053        };
4054
4055        let optimized = optimizer.optimize(plan);
4056
4057        // Filter should be pushed into right side of join with Column(2) (remapped from 5-3=2)
4058        if let RirNode::Join { right, .. } = optimized {
4059            if let RirNode::Filter { predicate, .. } = *right {
4060                if let Expr::Compare { left, .. } = predicate {
4061                    if let Expr::Column(idx) = *left {
4062                        assert_eq!(
4063                            idx, 2,
4064                            "Column should be remapped to 2 (5 - left_width(3) = 2)"
4065                        );
4066                    } else {
4067                        panic!("Expected Column expression");
4068                    }
4069                } else {
4070                    panic!("Expected Compare predicate");
4071                }
4072            } else {
4073                panic!("Expected Filter on right side of join");
4074            }
4075        } else {
4076            panic!("Expected Join node");
4077        }
4078    }
4079
4080    /// v0.6.5 slice 1: optimizer arms for `MultiWayJoin`.
4081    ///
4082    /// The promoter runs after `Optimizer::optimize` in `Compiler`, so
4083    /// these arms are unreachable in production. They exist for compile
4084    /// safety and to pin the documented semantics: `optimize` returns
4085    /// the node unchanged, `estimate_width` reports the head arity from
4086    /// `output_columns`, `estimate_cost` is the sum of input costs, and
4087    /// `find_column_relation` returns `None` (per slice 1 guardrail).
4088    ///
4089    /// v0.6.5 slice 2 (D5) extends each test below to also exercise a
4090    /// synthesized 4-input `MultiWayJoin` via [`build_4input_multiway`].
4091    /// This pins shape-agnosticism: the arms must NOT hard-code
4092    /// `inputs.len() == 3` or `output_columns.len() == 3`. Slice 2a
4093    /// (4-way) will produce real 4-input bodies through the promoter;
4094    /// these tests are the load-bearing guard against silent regression.
4095    fn build_canonical_triangle_multiway() -> RirNode {
4096        let scan_xy = RirNode::Scan { rel: RelId(1) };
4097        let scan_yz = RirNode::Scan { rel: RelId(2) };
4098        let scan_xz = RirNode::Scan { rel: RelId(3) };
4099        let inner_join = RirNode::Join {
4100            left: Box::new(scan_xy.clone()),
4101            right: Box::new(scan_yz.clone()),
4102            left_keys: vec![1],
4103            right_keys: vec![0],
4104            join_type: JoinType::Inner,
4105        };
4106        let outer_join = RirNode::Join {
4107            left: Box::new(inner_join),
4108            right: Box::new(scan_xz.clone()),
4109            left_keys: vec![0, 3],
4110            right_keys: vec![0, 1],
4111            join_type: JoinType::Inner,
4112        };
4113        let fallback = RirNode::Project {
4114            input: Box::new(outer_join),
4115            columns: vec![
4116                ProjectExpr::Column(0),
4117                ProjectExpr::Column(1),
4118                ProjectExpr::Column(3),
4119            ],
4120        };
4121        RirNode::MultiWayJoin {
4122            inputs: vec![scan_xy, scan_yz, scan_xz],
4123            slot_vars: vec![
4124                vec![Some(0), Some(1)],
4125                vec![Some(1), Some(2)],
4126                vec![Some(0), Some(2)],
4127            ],
4128            output_columns: vec![
4129                ProjectExpr::Column(0),
4130                ProjectExpr::Column(1),
4131                ProjectExpr::Column(3),
4132            ],
4133            fallback: Box::new(fallback),
4134            plan: None,
4135            var_order: None,
4136        }
4137    }
4138
4139    /// v0.6.5 slice 2 (D5): synthesized 4-input `MultiWayJoin` for
4140    /// shape-agnosticism testing. Slice 1's promoter is triangle-only,
4141    /// so this shape never reaches `Optimizer` through the production
4142    /// pipeline; the tests below exercise the optimizer arms directly.
4143    ///
4144    /// Inputs reuse `RelId(1, 2, 3, 1)` — RelId(1) repeats — so the
4145    /// stats manager registered in `make_stats_manager` covers all
4146    /// four scans. Cost floor is `2*10_000 + 5_000 + 1_000 = 26_000`.
4147    fn build_4input_multiway() -> RirNode {
4148        let scans = [RelId(1), RelId(2), RelId(3), RelId(1)]
4149            .map(|rel| RirNode::Scan { rel })
4150            .to_vec();
4151        // 4-cycle slot_vars [[A,B],[B,C],[C,D],[A,D]].
4152        let slot_vars = vec![
4153            vec![Some(0u32), Some(1)],
4154            vec![Some(1u32), Some(2)],
4155            vec![Some(2u32), Some(3)],
4156            vec![Some(0u32), Some(3)],
4157        ];
4158        // 4-arity head projection (no real semantic meaning — the
4159        // synthesized fallback is a stub).
4160        let output_columns = vec![
4161            ProjectExpr::Column(0),
4162            ProjectExpr::Column(1),
4163            ProjectExpr::Column(2),
4164            ProjectExpr::Column(3),
4165        ];
4166        // Stub fallback: the optimizer arms do not execute fallback,
4167        // so any RirNode is fine. Use Unit to keep the fixture small.
4168        let fallback = RirNode::Unit;
4169        RirNode::MultiWayJoin {
4170            inputs: scans,
4171            slot_vars,
4172            output_columns,
4173            fallback: Box::new(fallback),
4174            plan: None,
4175            var_order: None,
4176        }
4177    }
4178
4179    #[test]
4180    fn optimize_returns_multiway_unchanged() {
4181        let optimizer = Optimizer::new(make_stats_manager());
4182        for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4183            let optimized = optimizer.optimize(node.clone());
4184            match (&node, &optimized) {
4185                (
4186                    RirNode::MultiWayJoin {
4187                        inputs: a_in,
4188                        output_columns: a_out,
4189                        ..
4190                    },
4191                    RirNode::MultiWayJoin {
4192                        inputs: b_in,
4193                        output_columns: b_out,
4194                        ..
4195                    },
4196                ) => {
4197                    assert_eq!(a_in.len(), b_in.len());
4198                    assert_eq!(a_out.len(), b_out.len());
4199                }
4200                _ => panic!("optimize() must return a MultiWayJoin"),
4201            }
4202        }
4203    }
4204
4205    #[test]
4206    fn estimate_width_uses_output_columns_arity() {
4207        let optimizer = Optimizer::new(make_stats_manager());
4208        // Canonical triangle: 3 head columns.
4209        assert_eq!(
4210            optimizer.estimate_width(&build_canonical_triangle_multiway()),
4211            3
4212        );
4213        // 4-input synthesized: 4 head columns. Locks shape-
4214        // agnosticism — the arm must use output_columns.len(),
4215        // not a hard-coded 3.
4216        assert_eq!(optimizer.estimate_width(&build_4input_multiway()), 4);
4217    }
4218
4219    #[test]
4220    fn estimate_cost_sums_input_costs() {
4221        let optimizer = Optimizer::new(make_stats_manager());
4222
4223        // Canonical triangle: rels 1, 2, 3 with cardinalities
4224        // 10_000 + 5_000 + 1_000 = 16_000.
4225        let cost_tri = optimizer.estimate_cost(&build_canonical_triangle_multiway());
4226        assert!(
4227            cost_tri.rows >= 16_000,
4228            "expected cost.rows >= 16000, got {}",
4229            cost_tri.rows
4230        );
4231
4232        // 4-input synthesized: rels 1, 2, 3, 1 → 2*10_000 + 5_000 +
4233        // 1_000 = 26_000. The arm sums all four inputs; cost grows.
4234        // Locks shape-agnosticism — the arm must walk every entry
4235        // in `inputs`, not a hard-coded 3.
4236        let cost_4 = optimizer.estimate_cost(&build_4input_multiway());
4237        assert!(
4238            cost_4.rows >= 26_000,
4239            "expected 4-input cost.rows >= 26000, got {}",
4240            cost_4.rows
4241        );
4242        assert!(
4243            cost_4.rows > cost_tri.rows,
4244            "4-input cost ({}) must exceed triangle cost ({})",
4245            cost_4.rows,
4246            cost_tri.rows
4247        );
4248    }
4249
4250    #[test]
4251    fn find_column_relation_returns_none_for_multiway() {
4252        let optimizer = Optimizer::new(make_stats_manager());
4253        // Per slice 1 guardrail: no column-to-input mapping in this
4254        // slice. Half-mapped is more dangerous than None. The arm
4255        // must return None regardless of arity — slice 2 strengthens
4256        // this to also check the 4-input synthesized shape so a
4257        // future "let's just return inputs[col_idx % len]" patch
4258        // gets caught.
4259        for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4260            for col in 0..node.referenced_relations().len() {
4261                assert!(
4262                    optimizer.find_column_relation(&node, col).is_none(),
4263                    "find_column_relation must return None for any \
4264                     MultiWayJoin column (col={})",
4265                    col,
4266                );
4267            }
4268        }
4269    }
4270}