vibesql_storage/statistics/
column.rs

1//! Column-level statistics for selectivity estimation
2
3use std::collections::HashMap;
4
5use vibesql_types::SqlValue;
6
7use super::histogram::{BucketStrategy, Histogram};
8
9/// Statistics for a single column
10#[derive(Debug, Clone)]
11pub struct ColumnStatistics {
12    /// Number of distinct values (cardinality)
13    pub n_distinct: usize,
14
15    /// Number of NULL values
16    pub null_count: usize,
17
18    /// Minimum value (for range queries)
19    pub min_value: Option<SqlValue>,
20
21    /// Maximum value (for range queries)
22    pub max_value: Option<SqlValue>,
23
24    /// Most common values with their frequencies (top 10)
25    pub most_common_values: Vec<(SqlValue, f64)>,
26
27    /// Optional histogram for improved selectivity estimation (Phase 5.1)
28    pub histogram: Option<Histogram>,
29}
30
31impl ColumnStatistics {
32    /// Compute statistics for a column by scanning all rows
33    pub fn compute(rows: &[crate::Row], column_idx: usize) -> Self {
34        Self::compute_with_histogram(rows, column_idx, false, 100, BucketStrategy::EqualDepth)
35    }
36
37    /// Compute statistics with optional histogram support
38    ///
39    /// # Arguments
40    /// * `rows` - The rows to analyze
41    /// * `column_idx` - Index of the column
42    /// * `enable_histogram` - Whether to build a histogram
43    /// * `num_buckets` - Number of histogram buckets (default: 100)
44    /// * `bucket_strategy` - Histogram bucketing strategy
45    pub fn compute_with_histogram(
46        rows: &[crate::Row],
47        column_idx: usize,
48        enable_histogram: bool,
49        num_buckets: usize,
50        bucket_strategy: BucketStrategy,
51    ) -> Self {
52        let mut distinct_values = std::collections::HashSet::new();
53        let mut null_count = 0;
54        let mut min_value: Option<SqlValue> = None;
55        let mut max_value: Option<SqlValue> = None;
56        let mut value_counts: HashMap<SqlValue, usize> = HashMap::new();
57        let mut non_null_values = Vec::new();
58
59        for row in rows {
60            if column_idx >= row.values.len() {
61                continue;
62            }
63
64            let value = &row.values[column_idx];
65
66            if value.is_null() {
67                null_count += 1;
68                continue;
69            }
70
71            distinct_values.insert(value.clone());
72            *value_counts.entry(value.clone()).or_insert(0) += 1;
73            non_null_values.push(value.clone());
74
75            // Track min/max
76            match (&min_value, &max_value) {
77                (None, None) => {
78                    min_value = Some(value.clone());
79                    max_value = Some(value.clone());
80                }
81                (Some(min), Some(max)) => {
82                    if value < min {
83                        min_value = Some(value.clone());
84                    }
85                    if value > max {
86                        max_value = Some(value.clone());
87                    }
88                }
89                _ => unreachable!(),
90            }
91        }
92
93        // Extract most common values (top 10)
94        let total_non_null = rows.len() - null_count;
95        let mut mcvs: Vec<_> = value_counts
96            .into_iter()
97            .map(|(val, count)| {
98                let frequency =
99                    if total_non_null > 0 { count as f64 / total_non_null as f64 } else { 0.0 };
100                (val, frequency)
101            })
102            .collect();
103        mcvs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104        mcvs.truncate(10);
105
106        // Build histogram if enabled and we have enough data
107        let histogram = if enable_histogram && non_null_values.len() > 10 {
108            Some(Histogram::create(&non_null_values, num_buckets, bucket_strategy))
109        } else {
110            None
111        };
112
113        ColumnStatistics {
114            n_distinct: distinct_values.len(),
115            null_count,
116            min_value,
117            max_value,
118            most_common_values: mcvs,
119            histogram,
120        }
121    }
122
123    /// Estimate selectivity of equality predicate: col = value
124    ///
125    /// Returns fraction of rows expected to match (0.0 to 1.0)
126    /// Uses histogram if available for improved accuracy (Phase 5.1)
127    pub fn estimate_eq_selectivity(&self, value: &SqlValue) -> f64 {
128        if value.is_null() {
129            return 0.0; // NULLs don't match in SQL equality
130        }
131
132        // Check if value is in most common values
133        if let Some((_, freq)) = self.most_common_values.iter().find(|(v, _)| v == value) {
134            return *freq;
135        }
136
137        // Use histogram if available (Phase 5.1)
138        if let Some(ref histogram) = self.histogram {
139            return histogram.estimate_equality_selectivity(value);
140        }
141
142        // Fallback: assume uniform distribution for non-MCVs
143        if self.n_distinct > 0 {
144            1.0 / (self.n_distinct as f64)
145        } else {
146            0.0
147        }
148    }
149
150    /// Estimate selectivity of inequality: col != value
151    pub fn estimate_ne_selectivity(&self, value: &SqlValue) -> f64 {
152        1.0 - self.estimate_eq_selectivity(value)
153    }
154
155    /// Estimate selectivity of range predicate: col > value or col < value
156    ///
157    /// Uses histogram if available (Phase 5.1), otherwise falls back to
158    /// min/max-based linear interpolation (assumes uniform distribution)
159    pub fn estimate_range_selectivity(&self, value: &SqlValue, operator: &str) -> f64 {
160        // Use histogram if available for better accuracy (Phase 5.1)
161        if let Some(ref histogram) = self.histogram {
162            return histogram.estimate_range_selectivity(operator, value);
163        }
164
165        // Fallback to min/max-based estimation
166        match (operator, &self.min_value, &self.max_value) {
167            (">", Some(min), Some(max)) | (">=", Some(min), Some(max)) => {
168                if value < min {
169                    return 1.0; // All values satisfy
170                }
171                if value >= max {
172                    return if operator == ">=" && value == max {
173                        1.0 / (self.n_distinct as f64)
174                    } else {
175                        0.0
176                    };
177                }
178                // Linear interpolation (rough estimate)
179                0.33
180            }
181            ("<", Some(min), Some(max)) | ("<=", Some(min), Some(max)) => {
182                if value > max {
183                    return 1.0;
184                }
185                if value <= min {
186                    return if operator == "<=" && value == min {
187                        1.0 / (self.n_distinct as f64)
188                    } else {
189                        0.0
190                    };
191                }
192                // Linear interpolation
193                0.33
194            }
195            _ => 0.33, // Default fallback
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use vibesql_types::SqlValue;

    use super::*;
    use crate::Row;

    /// Shorthand for building a VARCHAR value from a string literal
    fn varchar(text: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(text))
    }

    /// Distinct count, NULL count, min/max and MCV ranking on a tiny
    /// two-column table
    #[test]
    fn test_column_statistics_basic() {
        let rows = vec![
            Row::new(vec![SqlValue::Integer(1), varchar("a")]),
            Row::new(vec![SqlValue::Integer(2), varchar("b")]),
            Row::new(vec![SqlValue::Integer(3), varchar("a")]),
            Row::new(vec![SqlValue::Integer(4), SqlValue::Null]),
        ];

        // First column: four different integers, no NULLs
        let int_stats = ColumnStatistics::compute(&rows, 0);
        assert_eq!(int_stats.n_distinct, 4);
        assert_eq!(int_stats.null_count, 0);
        assert_eq!(int_stats.min_value, Some(SqlValue::Integer(1)));
        assert_eq!(int_stats.max_value, Some(SqlValue::Integer(4)));

        // Second column: 'a' twice, 'b' once, one NULL
        let text_stats = ColumnStatistics::compute(&rows, 1);
        assert_eq!(text_stats.n_distinct, 2);
        assert_eq!(text_stats.null_count, 1);

        // 'a' leads the MCV list with 2 of the 3 non-NULL values (~66.7%)
        assert_eq!(text_stats.most_common_values.len(), 2);
        let (top_value, top_frequency) = &text_stats.most_common_values[0];
        assert_eq!(*top_value, varchar("a"));
        assert!((*top_frequency - 0.667).abs() < 0.01);
    }

    /// Equality selectivity over five evenly distributed integers, plus the
    /// NULL special case
    #[test]
    fn test_selectivity_estimation() {
        let rows: Vec<Row> =
            (1..=5).map(|n| Row::new(vec![SqlValue::Integer(n)])).collect();

        let stats = ColumnStatistics::compute(&rows, 0);

        // Each of the five values accounts for 1/5 = 20% of the rows
        let matching = stats.estimate_eq_selectivity(&SqlValue::Integer(3));
        assert!((matching - 0.2).abs() < 0.01);

        // Equality against NULL selects nothing
        let null_match = stats.estimate_eq_selectivity(&SqlValue::Null);
        assert_eq!(null_match, 0.0);
    }
}