samkhya_datafusion/table_provider.rs
1//! `SamkhyaTableProvider` — the primary integration point for injecting
2//! samkhya-corrected column statistics into DataFusion's query planning.
3//!
4//! # Wrapping point: `TableProvider::statistics()`
5//!
6//! DataFusion attaches statistics to table providers, not to logical-plan
7//! nodes. The [`TableProvider`] trait exposes a `statistics()` hook
8//! (returning `Option<Statistics>`) that the planner consults when reasoning
9//! about cardinality, join order, and filter selectivity. Rewriting a
10//! `LogicalPlan` to "inject" stats is the wrong layer — that is observe-only
11//! plumbing. The right layer is a `TableProvider` shim that delegates every
12//! method to an inner provider *except* `statistics()`, where it folds in
13//! samkhya's feedback-driven corrections.
14//!
15//! We considered three wrapping points and chose the first:
16//!
17//! 1. **`TableProvider::statistics()`** (this module). Clean, stable surface
18//! in DataFusion 46. The planner calls it during analysis. Every adapter
19//! (Parquet, CSV, MemTable, Iceberg) flows through the same hook, so the
20//! shim is provider-agnostic.
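//! 2. `ExecutionPlan::statistics()`. Lower in the stack — requires
//!    wrapping the scan-side `ExecutionPlan` returned from `scan()`. Useful
//!    when the inner provider's logical stats are absent but its physical
//!    plan has them. It also acts as a complement to option 1: DataFusion
//!    46's physical planner reads `ExecutionPlan::statistics()` rather than
//!    `TableProvider::statistics()`. For that reason `scan()` in this module
//!    wraps the inner exec in `SamkhyaStatsExec`, which carries the same
//!    folded statistics.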
25//! 3. `OptimizerRule` rewriting `TableScan::source`. The original scaffold
26//! direction. The rewrite must construct a new `TableSource` (the logical
27//! counterpart of `TableProvider`) — duplicate state, version-fragile,
28//! and never propagates into the physical layer where the planner
29//! actually consults stats. Kept around as observe-only telemetry
30//! ([`crate::SamkhyaOptimizerRule`]).
31//!
32//! # LpBound posture
33//!
34//! Every value translated into DataFusion's `Precision<T>` is wrapped as
35//! [`Precision::Inexact`]. samkhya's corrections are feedback-driven
36//! estimates clamped by the LpBound pessimistic ceiling; they are never
37//! exact catalog counts. `Inexact` is the precision DataFusion's
38//! cost-based optimizer treats as "use this, but do not assume zero error".
39
40use std::any::Any;
41use std::borrow::Cow;
42use std::collections::HashMap;
43use std::sync::Arc;
44use std::sync::atomic::{AtomicUsize, Ordering};
45
46use async_trait::async_trait;
47use datafusion::arrow::datatypes::SchemaRef;
48use datafusion::catalog::Session;
49use datafusion::common::stats::Precision;
50use datafusion::common::{ColumnStatistics, Constraints, Result, Statistics};
51use datafusion::datasource::{TableProvider, TableType};
52use datafusion::logical_expr::dml::InsertOp;
53use datafusion::logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown};
54use datafusion::physical_plan::ExecutionPlan;
55use samkhya_core::stats::ColumnStats;
56
57use crate::physical_plan::SamkhyaStatsExec;
58use crate::stats_provider::to_datafusion_column_statistics;
59
60/// A [`TableProvider`] wrapper that overrides `statistics()` with
61/// samkhya-corrected column statistics while delegating every other method
62/// to the inner provider.
63///
64/// # Builder
65///
66/// ```ignore
67/// use std::sync::Arc;
68/// use samkhya_datafusion::SamkhyaTableProvider;
69/// use samkhya_core::stats::ColumnStats;
70///
71/// let wrapped = SamkhyaTableProvider::new(Arc::new(inner))
72/// .with_column_stats(0, ColumnStats::new().with_row_count(999).with_distinct_count(42));
73/// ```
74///
75/// # Stats fold semantics
76///
77/// `statistics()` builds a `Statistics` whose per-column entries come from
78/// the samkhya override map where present, falling back to the inner
79/// provider's stats (or `ColumnStatistics::new_unknown()` if the inner
80/// provider returns `None`). Table-level `num_rows` is taken from the
81/// override map's most authoritative `row_count`: the maximum across all
82/// override entries, since samkhya's per-column stats describe the same
83/// underlying relation. If no override carries a row count, the inner
84/// provider's `num_rows` is preserved.
85#[derive(Debug)]
86pub struct SamkhyaTableProvider {
87 inner: Arc<dyn TableProvider>,
88 overrides: HashMap<usize, ColumnStats>,
89 /// Number of times `statistics()` has been invoked by the planner.
90 /// Exposed for integration tests; not part of the public optimization
91 /// contract.
92 stats_calls: AtomicUsize,
93}
94
95impl SamkhyaTableProvider {
96 /// Wrap an existing provider. No overrides are installed until
97 /// [`Self::with_column_stats`] is called.
98 pub fn new(inner: Arc<dyn TableProvider>) -> Self {
99 Self {
100 inner,
101 overrides: HashMap::new(),
102 stats_calls: AtomicUsize::new(0),
103 }
104 }
105
106 /// Install a samkhya override for the column at `col_idx`.
107 ///
108 /// Indices refer to positions in the inner provider's [`SchemaRef`].
109 /// Existing overrides for the same index are replaced.
110 pub fn with_column_stats(mut self, col_idx: usize, stats: ColumnStats) -> Self {
111 self.overrides.insert(col_idx, stats);
112 self
113 }
114
115 /// Number of times `statistics()` has been called on this wrapper.
116 ///
117 /// Useful for assertions in integration tests that verify the planner
118 /// actually consulted the corrected stats.
119 pub fn stats_call_count(&self) -> usize {
120 self.stats_calls.load(Ordering::SeqCst)
121 }
122
123 /// Borrow the override map. Read-only access for diagnostics.
124 pub fn overrides(&self) -> &HashMap<usize, ColumnStats> {
125 &self.overrides
126 }
127}
128
129#[async_trait]
130impl TableProvider for SamkhyaTableProvider {
131 fn as_any(&self) -> &dyn Any {
132 self
133 }
134
135 fn schema(&self) -> SchemaRef {
136 self.inner.schema()
137 }
138
139 fn constraints(&self) -> Option<&Constraints> {
140 self.inner.constraints()
141 }
142
143 fn table_type(&self) -> TableType {
144 self.inner.table_type()
145 }
146
147 fn get_table_definition(&self) -> Option<&str> {
148 self.inner.get_table_definition()
149 }
150
151 fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
152 self.inner.get_logical_plan()
153 }
154
155 fn get_column_default(&self, column: &str) -> Option<&Expr> {
156 self.inner.get_column_default(column)
157 }
158
159 async fn scan(
160 &self,
161 state: &dyn Session,
162 projection: Option<&Vec<usize>>,
163 filters: &[Expr],
164 limit: Option<usize>,
165 ) -> Result<Arc<dyn ExecutionPlan>> {
166 // Ask the inner provider for its native scan exec, then wrap it
167 // in `SamkhyaStatsExec` so the physical layer publishes the
168 // samkhya-corrected `Statistics` to every downstream operator.
169 //
170 // This is the actual injection path: DataFusion 46's mainline
171 // planner does not consult `TableProvider::statistics()` when
172 // building the physical plan — it calls `scan()` and trusts the
173 // returned `ExecutionPlan::statistics()`. So the only reliable
174 // way to flow corrected row counts into
175 // `physical.statistics()?.num_rows` is to override at the exec
176 // level, here.
177 //
178 // If we have no overrides installed we still wrap, using the
179 // statistics() fold as-is — the cost is one cheap delegation
180 // call per execute()/statistics() and the inner provider's
181 // values are preserved by the merge in `self.statistics()`.
182 let inner_plan = self.inner.scan(state, projection, filters, limit).await?;
183
184 // Project the table-level Statistics onto the scan's *output*
185 // schema (which honours `projection`), so the wrapped exec
186 // reports column_statistics aligned to the columns it actually
187 // emits — not the full table schema. This matches what
188 // `TableProvider`-aware execs (`DataSourceExec`) already do.
189 let full_stats = self
190 .statistics()
191 .unwrap_or_else(|| Statistics::new_unknown(self.inner.schema().as_ref()));
192 let output_stats = full_stats.project(projection);
193
194 Ok(Arc::new(SamkhyaStatsExec::new(inner_plan, output_stats)))
195 }
196
197 fn supports_filters_pushdown(
198 &self,
199 filters: &[&Expr],
200 ) -> Result<Vec<TableProviderFilterPushDown>> {
201 self.inner.supports_filters_pushdown(filters)
202 }
203
204 /// Fold samkhya overrides into the inner provider's `Statistics`.
205 ///
206 /// Schema order is preserved: column `i` in the returned
207 /// `column_statistics` corresponds to field `i` of `self.schema()`.
208 fn statistics(&self) -> Option<Statistics> {
209 // Record the call so tests can assert the planner consulted us.
210 self.stats_calls.fetch_add(1, Ordering::SeqCst);
211
212 let schema = self.inner.schema();
213 let n_fields = schema.fields().len();
214
215 // Start from the inner provider's stats; fall back to an unknown
216 // skeleton sized to the schema so we always return Some(_).
217 let mut base = self
218 .inner
219 .statistics()
220 .unwrap_or_else(|| Statistics::new_unknown(schema.as_ref()));
221
222 // Defensive: if the inner provider returned a column_statistics vec
223 // whose length disagrees with the schema, normalise to schema size.
224 if base.column_statistics.len() != n_fields {
225 base.column_statistics = Statistics::unknown_column(schema.as_ref());
226 }
227
228 // Per-column merge: override wins where present, inner is preserved
229 // otherwise. samkhya values are translated as Inexact per the
230 // LpBound conservative posture.
231 for (col_idx, override_stats) in &self.overrides {
232 if *col_idx >= n_fields {
233 // Index out of range — skip rather than panic; this can
234 // happen if the schema changes under us.
235 continue;
236 }
237 let translated = to_datafusion_column_statistics(override_stats);
238 base.column_statistics[*col_idx] =
239 merge_column_stats(base.column_statistics[*col_idx].clone(), translated);
240 }
241
242 // Table-level row count: take the max row_count across overrides
243 // (they all describe the same relation, so any populated value is
244 // a corrected estimate of |R|). If no override carries a row
245 // count, keep the inner provider's value.
246 //
247 // WAVE5-RC2: plan-memory-monotonic guard. Never publish a row
248 // count smaller than the inner provider's native estimate. The
249 // hash-join build-side sizing in DataFusion 46 picks the smaller
250 // side as the build side; if samkhya under-estimates and the
251 // actual data is much larger, the build hash table grows past
252 // its sized allocation and the planner walks into an OOM. Capping
253 // the published row count at `max(samkhya, native)` preserves
254 // samkhya's win when it has a larger / more accurate NDV-derived
255 // row count, while never pushing the planner toward a smaller
256 // build side than it would have chosen with no samkhya input.
257 // Symmetric guard on `SamkhyaStatsExec::statistics()` enforces
258 // the same invariant at the physical layer.
259 let override_row_count = self.overrides.values().filter_map(|s| s.row_count).max();
260 if let Some(rc) = override_row_count {
261 let rc_usize = rc as usize;
262 let monotone_rc = match base.num_rows {
263 Precision::Exact(n) | Precision::Inexact(n) => rc_usize.max(n),
264 Precision::Absent => rc_usize,
265 };
266 base.num_rows = Precision::Inexact(monotone_rc);
267 // Total byte size: if the inner provider reported it, relax to
268 // inexact since the row count has shifted; otherwise leave
269 // absent.
270 base.total_byte_size = match base.total_byte_size {
271 Precision::Exact(n) | Precision::Inexact(n) => Precision::Inexact(n),
272 Precision::Absent => Precision::Absent,
273 };
274 }
275
276 Some(base)
277 }
278
279 async fn insert_into(
280 &self,
281 state: &dyn Session,
282 input: Arc<dyn ExecutionPlan>,
283 insert_op: InsertOp,
284 ) -> Result<Arc<dyn ExecutionPlan>> {
285 self.inner.insert_into(state, input, insert_op).await
286 }
287}
288
289/// Merge a samkhya-translated `ColumnStatistics` over a base one.
290///
291/// Fields where the override is `Precision::Absent` fall through to the
292/// base. Fields where the override carries an `Inexact` value win for
293/// null_count / max_value / min_value / sum_value. **`distinct_count`
294/// applies the WAVE5-RC2 plan-memory-monotonic guard** — the published
295/// value is `max(samkhya_ndv, native_ndv)`. NDV drives hash-join
296/// build-side sizing; never publishing a smaller distinct count than
297/// DataFusion's native estimate prevents the corrected arm from
298/// pushing the planner toward a smaller build hash table than it would
299/// have chosen with no samkhya input.
300fn merge_column_stats(base: ColumnStatistics, ovr: ColumnStatistics) -> ColumnStatistics {
301 ColumnStatistics {
302 null_count: pick(base.null_count, ovr.null_count),
303 max_value: pick(base.max_value, ovr.max_value),
304 min_value: pick(base.min_value, ovr.min_value),
305 sum_value: pick(base.sum_value, ovr.sum_value),
306 distinct_count: pick_max_usize(base.distinct_count, ovr.distinct_count),
307 }
308}
309
310/// Plan-memory-monotonic merge for `Precision<usize>` cardinality
311/// fields. Returns `Precision::Inexact(max(base, ovr))` when both
312/// carry a value, the present one when only one does, and
313/// `Precision::Absent` when neither does. Used for `distinct_count`
314/// merges so samkhya never publishes an NDV smaller than the inner
315/// provider would have on its own.
316fn pick_max_usize(base: Precision<usize>, ovr: Precision<usize>) -> Precision<usize> {
317 let base_val = match base {
318 Precision::Exact(n) | Precision::Inexact(n) => Some(n),
319 Precision::Absent => None,
320 };
321 let ovr_val = match ovr {
322 Precision::Exact(n) | Precision::Inexact(n) => Some(n),
323 Precision::Absent => None,
324 };
325 match (base_val, ovr_val) {
326 (Some(b), Some(o)) => Precision::Inexact(b.max(o)),
327 (Some(b), None) => Precision::Inexact(b),
328 (None, Some(o)) => Precision::Inexact(o),
329 (None, None) => Precision::Absent,
330 }
331}
332
333fn pick<T>(base: Precision<T>, ovr: Precision<T>) -> Precision<T>
334where
335 T: std::fmt::Debug + Clone + PartialEq + Eq + PartialOrd,
336{
337 match ovr {
338 Precision::Absent => base,
339 other => other,
340 }
341}
342
343#[cfg(test)]
344mod tests {
345 use super::*;
346 use datafusion::arrow::array::Int64Array;
347 use datafusion::arrow::datatypes::{DataType, Field, Schema};
348 use datafusion::arrow::record_batch::RecordBatch;
349 use datafusion::datasource::MemTable;
350
351 fn tiny_mem_table() -> Arc<MemTable> {
352 let schema = Arc::new(Schema::new(vec![
353 Field::new("a", DataType::Int64, false),
354 Field::new("b", DataType::Int64, false),
355 ]));
356 let batch = RecordBatch::try_new(
357 Arc::clone(&schema),
358 vec![
359 Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
360 Arc::new(Int64Array::from(vec![10, 20, 30, 40, 50])),
361 ],
362 )
363 .unwrap();
364 Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap())
365 }
366
367 #[test]
368 fn builder_records_overrides() {
369 let inner = tiny_mem_table();
370 let wrapped = SamkhyaTableProvider::new(inner)
371 .with_column_stats(0, ColumnStats::new().with_row_count(999));
372 assert_eq!(wrapped.overrides().len(), 1);
373 assert_eq!(wrapped.overrides()[&0].row_count, Some(999));
374 }
375
376 #[test]
377 fn statistics_overrides_row_count() {
378 let inner = tiny_mem_table();
379 let wrapped = SamkhyaTableProvider::new(inner).with_column_stats(
380 0,
381 ColumnStats::new()
382 .with_row_count(999)
383 .with_distinct_count(42),
384 );
385 let stats = wrapped.statistics().expect("statistics present");
386 assert_eq!(stats.num_rows, Precision::Inexact(999));
387 assert_eq!(
388 stats.column_statistics[0].distinct_count,
389 Precision::Inexact(42)
390 );
391 assert_eq!(wrapped.stats_call_count(), 1);
392 }
393
394 #[test]
395 fn statistics_falls_back_for_unoverridden_columns() {
396 let inner = tiny_mem_table();
397 let wrapped = SamkhyaTableProvider::new(inner)
398 .with_column_stats(0, ColumnStats::new().with_distinct_count(7));
399 let stats = wrapped.statistics().expect("statistics present");
400 assert_eq!(
401 stats.column_statistics[0].distinct_count,
402 Precision::Inexact(7)
403 );
404 // Column 1 has no override and the inner MemTable does not report
405 // stats — so the slot stays at Absent.
406 assert_eq!(stats.column_statistics[1].distinct_count, Precision::Absent);
407 }
408
409 #[test]
410 fn out_of_range_override_is_ignored() {
411 let inner = tiny_mem_table();
412 let wrapped = SamkhyaTableProvider::new(inner)
413 .with_column_stats(99, ColumnStats::new().with_distinct_count(123));
414 // No panic, statistics still produced.
415 let stats = wrapped.statistics().expect("statistics present");
416 assert_eq!(stats.column_statistics.len(), 2);
417 }
418
419 /// WAVE5-RC2: when the inner provider's native row count exceeds the
420 /// samkhya override, the published value is the inner provider's
421 /// estimate, not the (smaller) samkhya value. Prevents the planner
422 /// from picking a smaller hash-join build side than baseline would.
423 ///
424 /// Uses a minimal mock provider that returns a known
425 /// `Precision::Inexact(5)` for num_rows, since `MemTable`'s default
426 /// stats path leaves num_rows as `Precision::Absent` (which would
427 /// trip the fallback branch, not the monotone-cap branch).
428 #[test]
429 fn statistics_row_count_caps_at_max_of_samkhya_and_native() {
430 use async_trait::async_trait;
431 use datafusion::catalog::Session;
432 use datafusion::common::Result as DfResult;
433 use datafusion::datasource::{TableProvider, TableType};
434 use datafusion::logical_expr::Expr;
435 use datafusion::physical_plan::ExecutionPlan;
436
437 #[derive(Debug)]
438 struct MockProvider {
439 schema: SchemaRef,
440 native_rows: usize,
441 }
442
443 #[async_trait]
444 impl TableProvider for MockProvider {
445 fn as_any(&self) -> &dyn Any {
446 self
447 }
448 fn schema(&self) -> SchemaRef {
449 Arc::clone(&self.schema)
450 }
451 fn table_type(&self) -> TableType {
452 TableType::Base
453 }
454 async fn scan(
455 &self,
456 _state: &dyn Session,
457 _projection: Option<&Vec<usize>>,
458 _filters: &[Expr],
459 _limit: Option<usize>,
460 ) -> DfResult<Arc<dyn ExecutionPlan>> {
461 unreachable!("scan not exercised by this test")
462 }
463 fn statistics(&self) -> Option<Statistics> {
464 let mut s = Statistics::new_unknown(self.schema.as_ref());
465 s.num_rows = Precision::Inexact(self.native_rows);
466 Some(s)
467 }
468 }
469
470 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
471 let inner: Arc<dyn TableProvider> = Arc::new(MockProvider {
472 schema: Arc::clone(&schema),
473 native_rows: 5,
474 });
475 let wrapped = SamkhyaTableProvider::new(inner)
476 .with_column_stats(0, ColumnStats::new().with_row_count(3));
477 let stats = wrapped.statistics().expect("statistics present");
478 assert_eq!(
479 stats.num_rows,
480 Precision::Inexact(5),
481 "monotone cap must publish max(samkhya=3, native=5)=5, not the smaller samkhya estimate"
482 );
483 }
484
485 /// WAVE5-RC2: symmetric column-level guard. When samkhya's NDV
486 /// override is smaller than the inner provider's native NDV, publish
487 /// the native (larger) value to keep hash-join build sides on the
488 /// safe side.
489 #[test]
490 fn statistics_distinct_count_caps_at_max_of_samkhya_and_native() {
491 // Hand-construct a base ColumnStatistics with a known native
492 // distinct_count and feed it through merge_column_stats with a
493 // smaller samkhya override. Expected: max() wins.
494 let base = ColumnStatistics {
495 null_count: Precision::Absent,
496 max_value: Precision::Absent,
497 min_value: Precision::Absent,
498 sum_value: Precision::Absent,
499 distinct_count: Precision::Inexact(1000),
500 };
501 let ovr = ColumnStatistics {
502 null_count: Precision::Absent,
503 max_value: Precision::Absent,
504 min_value: Precision::Absent,
505 sum_value: Precision::Absent,
506 distinct_count: Precision::Inexact(50),
507 };
508 let merged = merge_column_stats(base, ovr);
509 assert_eq!(
510 merged.distinct_count,
511 Precision::Inexact(1000),
512 "merge must publish max(samkhya, native) distinct_count"
513 );
514 }
515}