vibesql_storage/statistics/
column.rs1use std::collections::HashMap;
4
5use vibesql_types::SqlValue;
6
7use super::histogram::{BucketStrategy, Histogram};
8
9#[derive(Debug, Clone)]
11pub struct ColumnStatistics {
12 pub n_distinct: usize,
14
15 pub null_count: usize,
17
18 pub min_value: Option<SqlValue>,
20
21 pub max_value: Option<SqlValue>,
23
24 pub most_common_values: Vec<(SqlValue, f64)>,
26
27 pub histogram: Option<Histogram>,
29}
30
31impl ColumnStatistics {
32 pub fn compute(rows: &[crate::Row], column_idx: usize) -> Self {
34 Self::compute_with_histogram(rows, column_idx, false, 100, BucketStrategy::EqualDepth)
35 }
36
37 pub fn compute_with_histogram(
46 rows: &[crate::Row],
47 column_idx: usize,
48 enable_histogram: bool,
49 num_buckets: usize,
50 bucket_strategy: BucketStrategy,
51 ) -> Self {
52 let mut distinct_values = std::collections::HashSet::new();
53 let mut null_count = 0;
54 let mut min_value: Option<SqlValue> = None;
55 let mut max_value: Option<SqlValue> = None;
56 let mut value_counts: HashMap<SqlValue, usize> = HashMap::new();
57 let mut non_null_values = Vec::new();
58
59 for row in rows {
60 if column_idx >= row.values.len() {
61 continue;
62 }
63
64 let value = &row.values[column_idx];
65
66 if value.is_null() {
67 null_count += 1;
68 continue;
69 }
70
71 distinct_values.insert(value.clone());
72 *value_counts.entry(value.clone()).or_insert(0) += 1;
73 non_null_values.push(value.clone());
74
75 match (&min_value, &max_value) {
77 (None, None) => {
78 min_value = Some(value.clone());
79 max_value = Some(value.clone());
80 }
81 (Some(min), Some(max)) => {
82 if value < min {
83 min_value = Some(value.clone());
84 }
85 if value > max {
86 max_value = Some(value.clone());
87 }
88 }
89 _ => unreachable!(),
90 }
91 }
92
93 let total_non_null = rows.len() - null_count;
95 let mut mcvs: Vec<_> = value_counts
96 .into_iter()
97 .map(|(val, count)| {
98 let frequency =
99 if total_non_null > 0 { count as f64 / total_non_null as f64 } else { 0.0 };
100 (val, frequency)
101 })
102 .collect();
103 mcvs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104 mcvs.truncate(10);
105
106 let histogram = if enable_histogram && non_null_values.len() > 10 {
108 Some(Histogram::create(&non_null_values, num_buckets, bucket_strategy))
109 } else {
110 None
111 };
112
113 ColumnStatistics {
114 n_distinct: distinct_values.len(),
115 null_count,
116 min_value,
117 max_value,
118 most_common_values: mcvs,
119 histogram,
120 }
121 }
122
123 pub fn estimate_eq_selectivity(&self, value: &SqlValue) -> f64 {
128 if value.is_null() {
129 return 0.0; }
131
132 if let Some((_, freq)) = self.most_common_values.iter().find(|(v, _)| v == value) {
134 return *freq;
135 }
136
137 if let Some(ref histogram) = self.histogram {
139 return histogram.estimate_equality_selectivity(value);
140 }
141
142 if self.n_distinct > 0 {
144 1.0 / (self.n_distinct as f64)
145 } else {
146 0.0
147 }
148 }
149
150 pub fn estimate_ne_selectivity(&self, value: &SqlValue) -> f64 {
152 1.0 - self.estimate_eq_selectivity(value)
153 }
154
155 pub fn estimate_range_selectivity(&self, value: &SqlValue, operator: &str) -> f64 {
160 if let Some(ref histogram) = self.histogram {
162 return histogram.estimate_range_selectivity(operator, value);
163 }
164
165 match (operator, &self.min_value, &self.max_value) {
167 (">", Some(min), Some(max)) | (">=", Some(min), Some(max)) => {
168 if value < min {
169 return 1.0; }
171 if value >= max {
172 return if operator == ">=" && value == max {
173 1.0 / (self.n_distinct as f64)
174 } else {
175 0.0
176 };
177 }
178 0.33
180 }
181 ("<", Some(min), Some(max)) | ("<=", Some(min), Some(max)) => {
182 if value > max {
183 return 1.0;
184 }
185 if value <= min {
186 return if operator == "<=" && value == min {
187 1.0 / (self.n_distinct as f64)
188 } else {
189 0.0
190 };
191 }
192 0.33
194 }
195 _ => 0.33, }
197 }
198}
199
#[cfg(test)]
mod tests {
    use vibesql_types::SqlValue;

    use super::*;
    use crate::Row;

    /// Builds a `Varchar` value from a string slice.
    fn varchar(text: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(text))
    }

    /// Checks distinct/null counts, min/max and the most-common-value list on
    /// a small two-column table.
    #[test]
    fn test_column_statistics_basic() {
        let rows = vec![
            Row::new(vec![SqlValue::Integer(1), varchar("a")]),
            Row::new(vec![SqlValue::Integer(2), varchar("b")]),
            Row::new(vec![SqlValue::Integer(3), varchar("a")]),
            Row::new(vec![SqlValue::Integer(4), SqlValue::Null]),
        ];

        // Integer column: four distinct values, no NULLs.
        let int_stats = ColumnStatistics::compute(&rows, 0);
        assert_eq!(int_stats.n_distinct, 4);
        assert_eq!(int_stats.null_count, 0);
        assert_eq!(int_stats.min_value, Some(SqlValue::Integer(1)));
        assert_eq!(int_stats.max_value, Some(SqlValue::Integer(4)));

        // Text column: "a" twice, "b" once, one NULL.
        let text_stats = ColumnStatistics::compute(&rows, 1);
        assert_eq!(text_stats.n_distinct, 2);
        assert_eq!(text_stats.null_count, 1);

        let mcvs = &text_stats.most_common_values;
        assert_eq!(mcvs.len(), 2);
        let (top_value, top_frequency) = &mcvs[0];
        assert_eq!(*top_value, varchar("a"));
        assert!((*top_frequency - 0.667).abs() < 0.01);
    }

    /// Checks equality selectivity for a present value and for NULL.
    #[test]
    fn test_selectivity_estimation() {
        let rows: Vec<Row> = (1..=5).map(|n| Row::new(vec![SqlValue::Integer(n)])).collect();

        let stats = ColumnStatistics::compute(&rows, 0);

        // Five equally frequent values: each accounts for one fifth.
        let present = stats.estimate_eq_selectivity(&SqlValue::Integer(3));
        assert!((present - 0.2).abs() < 0.01);

        // Equality with NULL never matches.
        let null_sel = stats.estimate_eq_selectivity(&SqlValue::Null);
        assert_eq!(null_sel, 0.0);
    }
}