Skip to main content

samkhya_datafusion/
physical_plan.rs

1//! `SamkhyaStatsExec` — the [`ExecutionPlan`]-layer wrapper that actually
2//! flows samkhya-corrected statistics into DataFusion 46's physical plan.
3//!
4//! # Why this exists
5//!
6//! In DataFusion 46.0.1 the mainline physical planner does **not** call
7//! [`TableProvider::statistics()`] when it constructs the leaf
8//! `ExecutionPlan` for a scan: it just invokes
9//! [`TableProvider::scan()`] and uses whatever exec is returned. The
10//! exec's own [`ExecutionPlan::statistics()`] is what later operators
11//! (`FilterExec`, `ProjectionExec`, `HashJoinExec`, …) propagate up the
12//! tree.
13//!
14//! To get samkhya's corrections into the row-count numbers reported by
15//! `ctx.state().create_physical_plan(&plan).await?.statistics()`, we
16//! therefore need to override `statistics()` at the *physical* layer.
17//! `SamkhyaStatsExec` is a thin passthrough wrapper for that single
18//! purpose:
19//!
20//! * everything (schema, partitioning, equivalence, execute) delegates to
21//!   the inner [`ExecutionPlan`];
22//! * `statistics()` returns the override [`Statistics`] supplied at
23//!   construction time, marked `Precision::Inexact` per the LpBound
24//!   conservative posture;
25//! * `with_new_children` rebuilds the wrapper around the new child,
26//!   preserving the override — so subsequent physical optimizer rules can
27//!   reshape the tree without losing the corrected stats.
28//!
29//! # Where it gets installed
30//!
31//! [`crate::SamkhyaTableProvider`] installs this wrapper from inside its
32//! `scan()` implementation: the inner provider returns its native exec,
33//! and the table provider wraps it in `SamkhyaStatsExec` carrying the
34//! samkhya-corrected `Statistics`. The mainline planner then sees the
35//! wrapped exec as the scan leaf, so its `statistics()` are the samkhya
36//! values — and propagation through filters / projections / joins
37//! produces a final plan whose `.statistics()?.num_rows` reflects the
38//! corrections.
39//!
40//! [`TableProvider::statistics()`]: datafusion::datasource::TableProvider::statistics
41//! [`TableProvider::scan()`]: datafusion::datasource::TableProvider::scan
42
43use std::any::Any;
44use std::fmt;
45use std::sync::Arc;
46
47use datafusion::arrow::datatypes::SchemaRef;
48use datafusion::common::{Result, Statistics};
49use datafusion::execution::TaskContext;
50use datafusion::physical_plan::{
51    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
52};
53
54/// Passthrough [`ExecutionPlan`] wrapper that publishes samkhya-corrected
55/// statistics from `statistics()` while delegating every other method to
56/// the inner exec.
57///
58/// Construct via [`SamkhyaStatsExec::new`]; the wrapper holds an
59/// `Arc<dyn ExecutionPlan>` plus the `Statistics` it should report.
60#[derive(Debug, Clone)]
61pub struct SamkhyaStatsExec {
62    input: Arc<dyn ExecutionPlan>,
63    stats: Statistics,
64}
65
66impl SamkhyaStatsExec {
67    /// Wrap `input` so that calls to [`ExecutionPlan::statistics()`] on
68    /// the result return `stats` rather than the input's defaults.
69    pub fn new(input: Arc<dyn ExecutionPlan>, stats: Statistics) -> Self {
70        Self { input, stats }
71    }
72
73    /// Borrow the inner plan.
74    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
75        &self.input
76    }
77
78    /// Borrow the override statistics.
79    pub fn override_statistics(&self) -> &Statistics {
80        &self.stats
81    }
82}
83
84impl DisplayAs for SamkhyaStatsExec {
85    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
86        match t {
87            DisplayFormatType::Default | DisplayFormatType::Verbose => {
88                write!(f, "SamkhyaStatsExec: num_rows={:?}", self.stats.num_rows)
89            }
90        }
91    }
92}
93
94impl ExecutionPlan for SamkhyaStatsExec {
95    fn name(&self) -> &str {
96        "SamkhyaStatsExec"
97    }
98
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102
103    fn schema(&self) -> SchemaRef {
104        self.input.schema()
105    }
106
107    fn properties(&self) -> &PlanProperties {
108        // Passthrough: our output partitioning / ordering / equivalence
109        // match the inner plan exactly. Borrowing the inner cache keeps
110        // the wrapper allocation-free and avoids drift if the inner plan
111        // is itself rewritten by other physical-optimizer rules.
112        self.input.properties()
113    }
114
115    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
116        vec![&self.input]
117    }
118
119    fn maintains_input_order(&self) -> Vec<bool> {
120        // Passthrough preserves ordering trivially.
121        vec![true]
122    }
123
124    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
125        vec![false]
126    }
127
128    fn with_new_children(
129        self: Arc<Self>,
130        children: Vec<Arc<dyn ExecutionPlan>>,
131    ) -> Result<Arc<dyn ExecutionPlan>> {
132        // Preserve the override across rewrites — other physical
133        // optimizer rules will call `with_new_children` to swap our
134        // input, and we want the corrected stats to ride along.
135        let new_input = children
136            .into_iter()
137            .next()
138            .expect("SamkhyaStatsExec has exactly one child");
139        Ok(Arc::new(SamkhyaStatsExec::new(
140            new_input,
141            self.stats.clone(),
142        )))
143    }
144
145    fn execute(
146        &self,
147        partition: usize,
148        context: Arc<TaskContext>,
149    ) -> Result<SendableRecordBatchStream> {
150        self.input.execute(partition, context)
151    }
152
153    fn statistics(&self) -> Result<Statistics> {
154        Ok(self.stats.clone())
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::Int64Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::stats::Precision;
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::execution::context::SessionContext;

    /// Scan plan over a one-column (`a: Int64`), three-row in-memory
    /// table, used as the wrapped input in the tests below.
    async fn tiny_input_exec() -> Arc<dyn ExecutionPlan> {
        let field = Field::new("a", DataType::Int64, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let table = MemTable::try_new(Arc::clone(&schema), vec![vec![batch]]).unwrap();
        let table = Arc::new(table);
        let state = SessionContext::new().state();
        let session: &dyn datafusion::catalog::Session = &state;
        table.scan(session, None, &[], None).await.unwrap()
    }

    /// All-unknown statistics for `plan`'s schema, with only `num_rows`
    /// set to an inexact `rows`.
    fn inexact_row_stats(plan: &Arc<dyn ExecutionPlan>, rows: usize) -> Statistics {
        let mut stats = Statistics::new_unknown(plan.schema().as_ref());
        stats.num_rows = Precision::Inexact(rows);
        stats
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn wrapper_reports_override_stats() {
        let scan = tiny_input_exec().await;
        let override_stats = inexact_row_stats(&scan, 42);
        let exec: Arc<dyn ExecutionPlan> = Arc::new(SamkhyaStatsExec::new(scan, override_stats));

        let reported = exec.statistics().expect("stats present");
        assert_eq!(reported.num_rows, Precision::Inexact(42));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn with_new_children_preserves_override() {
        let scan = tiny_input_exec().await;
        let override_stats = inexact_row_stats(&scan, 7);
        let exec: Arc<dyn ExecutionPlan> =
            Arc::new(SamkhyaStatsExec::new(Arc::clone(&scan), override_stats));

        // Hand the wrapper the same child again: the rebuilt node must
        // still publish the override, or optimizer rewrites would drop
        // samkhya's corrections.
        let rebuilt = Arc::clone(&exec)
            .with_new_children(vec![scan])
            .expect("rebuild");
        let reported = rebuilt.statistics().unwrap();
        assert_eq!(reported.num_rows, Precision::Inexact(7));
    }
}