samkhya_datafusion/
physical_plan.rs1use std::any::Any;
44use std::fmt;
45use std::sync::Arc;
46
47use datafusion::arrow::datatypes::SchemaRef;
48use datafusion::common::{Result, Statistics};
49use datafusion::execution::TaskContext;
50use datafusion::physical_plan::{
51 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
52};
53
54#[derive(Debug, Clone)]
61pub struct SamkhyaStatsExec {
62 input: Arc<dyn ExecutionPlan>,
63 stats: Statistics,
64}
65
66impl SamkhyaStatsExec {
67 pub fn new(input: Arc<dyn ExecutionPlan>, stats: Statistics) -> Self {
70 Self { input, stats }
71 }
72
73 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
75 &self.input
76 }
77
78 pub fn override_statistics(&self) -> &Statistics {
80 &self.stats
81 }
82}
83
84impl DisplayAs for SamkhyaStatsExec {
85 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
86 match t {
87 DisplayFormatType::Default | DisplayFormatType::Verbose => {
88 write!(f, "SamkhyaStatsExec: num_rows={:?}", self.stats.num_rows)
89 }
90 }
91 }
92}
93
94impl ExecutionPlan for SamkhyaStatsExec {
95 fn name(&self) -> &str {
96 "SamkhyaStatsExec"
97 }
98
99 fn as_any(&self) -> &dyn Any {
100 self
101 }
102
103 fn schema(&self) -> SchemaRef {
104 self.input.schema()
105 }
106
107 fn properties(&self) -> &PlanProperties {
108 self.input.properties()
113 }
114
115 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
116 vec![&self.input]
117 }
118
119 fn maintains_input_order(&self) -> Vec<bool> {
120 vec![true]
122 }
123
124 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
125 vec![false]
126 }
127
128 fn with_new_children(
129 self: Arc<Self>,
130 children: Vec<Arc<dyn ExecutionPlan>>,
131 ) -> Result<Arc<dyn ExecutionPlan>> {
132 let new_input = children
136 .into_iter()
137 .next()
138 .expect("SamkhyaStatsExec has exactly one child");
139 Ok(Arc::new(SamkhyaStatsExec::new(
140 new_input,
141 self.stats.clone(),
142 )))
143 }
144
145 fn execute(
146 &self,
147 partition: usize,
148 context: Arc<TaskContext>,
149 ) -> Result<SendableRecordBatchStream> {
150 self.input.execute(partition, context)
151 }
152
153 fn statistics(&self) -> Result<Statistics> {
154 Ok(self.stats.clone())
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::Int64Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::stats::Precision;
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::execution::context::SessionContext;

    /// Scan plan over one in-memory batch with a single non-null Int64 column
    /// `a` holding `[1, 2, 3]`.
    async fn tiny_input_exec() -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let table = Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![batch]]).unwrap());

        let ctx = SessionContext::new();
        let state = ctx.state();
        let session: &dyn datafusion::catalog::Session = &state;
        table.scan(session, None, &[], None).await.unwrap()
    }

    /// Unknown statistics for `plan`'s schema, with `num_rows` pinned to
    /// `Precision::Inexact(rows)`.
    fn inexact_row_override(plan: &Arc<dyn ExecutionPlan>, rows: usize) -> Statistics {
        let mut overridden = Statistics::new_unknown(plan.schema().as_ref());
        overridden.num_rows = Precision::Inexact(rows);
        overridden
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn wrapper_reports_override_stats() {
        let child = tiny_input_exec().await;
        let overridden = inexact_row_override(&child, 42);
        let plan: Arc<dyn ExecutionPlan> = Arc::new(SamkhyaStatsExec::new(child, overridden));

        let reported = plan.statistics().expect("stats present");
        assert_eq!(reported.num_rows, Precision::Inexact(42));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn with_new_children_preserves_override() {
        let child = tiny_input_exec().await;
        let overridden = inexact_row_override(&child, 7);
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(SamkhyaStatsExec::new(Arc::clone(&child), overridden));

        let rebuilt = plan.with_new_children(vec![child]).expect("rebuild");
        let reported = rebuilt.statistics().unwrap();
        assert_eq!(reported.num_rows, Precision::Inexact(7));
    }
}