vegafusion_runtime/transform/
bin.rs1use crate::expression::compiler::compile;
2use crate::expression::compiler::config::CompilationConfig;
3use crate::expression::compiler::utils::ExprHelpers;
4use crate::transform::TransformTrait;
5use async_trait::async_trait;
6
7use datafusion_expr::lit;
8
9use datafusion::prelude::DataFrame;
10use datafusion_common::scalar::ScalarValue;
11use datafusion_common::utils::SingleRowListArrayBuilder;
12use datafusion_common::DFSchema;
13use datafusion_expr::when;
14use datafusion_functions::expr_fn::{abs, floor};
15use float_cmp::approx_eq;
16use std::ops::{Add, Div, Mul, Sub};
17use std::sync::Arc;
18use vegafusion_common::column::{flat_col, unescaped_col};
19use vegafusion_common::data::scalar::ScalarValueHelpers;
20use vegafusion_common::datatypes::to_numeric;
21use vegafusion_core::error::{Result, VegaFusionError};
22use vegafusion_core::proto::gen::transforms::Bin;
23use vegafusion_core::task_graph::task_value::TaskValue;
24
25#[async_trait]
26impl TransformTrait for Bin {
27 async fn eval(
28 &self,
29 sql_df: DataFrame,
30 config: &CompilationConfig,
31 ) -> Result<(DataFrame, Vec<TaskValue>)> {
32 let schema = sql_df.schema().clone();
33
34 let params = calculate_bin_params(self, &schema, config)?;
36
37 let BinParams {
38 start,
39 stop,
40 step,
41 n,
42 } = params;
43 let bin_starts: Vec<f64> = (0..n).map(|i| start + step * i as f64).collect();
44 let last_stop = *bin_starts.last().unwrap() + step;
45
46 let output_value = compute_output_value(self, start, stop, step);
48
49 let numeric_field = to_numeric(unescaped_col(&self.field), sql_df.schema())?;
50
51 let bin_index_name = "__bin_index";
53 let bin_index =
54 floor((numeric_field.clone().sub(lit(start)).div(lit(step))).add(lit(1.0e-14)))
55 .alias(bin_index_name);
56 let sql_df = sql_df.select(vec![datafusion_expr::expr_fn::wildcard(), bin_index.into()])?;
57
58 let bin_start = (flat_col(bin_index_name).mul(lit(step))).add(lit(start));
60 let bin_start_name = self.alias_0.clone().unwrap_or_else(|| "bin0".to_string());
61
62 let inf = lit(f64::INFINITY);
63 let neg_inf = lit(f64::NEG_INFINITY);
64 let eps = lit(1.0e-14);
65
66 let bin_start = when(flat_col(bin_index_name).lt(lit(0.0)), neg_inf)
67 .when(
68 abs(numeric_field.sub(lit(last_stop)))
69 .lt(eps)
70 .and(flat_col(bin_index_name).eq(lit(n))),
71 flat_col(bin_index_name)
72 .sub(lit(1))
73 .mul(lit(step))
74 .add(lit(start)),
75 )
76 .when(flat_col(bin_index_name).gt_eq(lit(n)), inf)
77 .otherwise(bin_start)?
78 .alias(&bin_start_name);
79
80 let mut select_exprs = sql_df
81 .schema()
82 .inner()
83 .fields
84 .iter()
85 .filter_map(|field| {
86 if field.name() == &bin_start_name {
87 None
88 } else {
89 Some(flat_col(field.name()))
90 }
91 })
92 .collect::<Vec<_>>();
93 select_exprs.push(bin_start);
94
95 let sql_df = sql_df.select(select_exprs)?;
96
97 let bin_end_name = self.alias_1.clone().unwrap_or_else(|| "bin1".to_string());
99 let bin_end = (flat_col(&bin_start_name) + lit(step)).alias(&bin_end_name);
100
101 let mut select_exprs = schema
103 .fields()
104 .iter()
105 .filter_map(|field| {
106 let name = field.name();
107 if name == &bin_start_name || name == &bin_end_name {
108 None
109 } else {
110 Some(flat_col(name))
111 }
112 })
113 .collect::<Vec<_>>();
114 select_exprs.push(flat_col(&bin_start_name));
115 select_exprs.push(bin_end);
116
117 let sql_df = sql_df.select(select_exprs)?;
118
119 Ok((sql_df, output_value.into_iter().collect()))
120 }
121}
122
123fn compute_output_value(bin_tx: &Bin, start: f64, stop: f64, step: f64) -> Option<TaskValue> {
124 let mut fname = bin_tx.field.clone();
125 fname.insert_str(0, "bin_");
126 let fields = ScalarValue::List(Arc::new(
127 SingleRowListArrayBuilder::new(
128 ScalarValue::iter_to_array(vec![ScalarValue::from(bin_tx.field.as_str())]).ok()?,
129 )
130 .with_nullable(true)
131 .build_list_array(),
132 ));
133
134 if bin_tx.signal.is_some() {
135 Some(TaskValue::Scalar(ScalarValue::from(vec![
136 ("fields", fields),
137 ("fname", ScalarValue::from(fname.as_str())),
138 ("start", ScalarValue::from(start)),
139 ("step", ScalarValue::from(step)),
140 ("stop", ScalarValue::from(stop)),
141 ])))
142 } else {
143 None
144 }
145}
146
/// Resolved binning parameters, as produced by `calculate_bin_params`.
#[derive(Clone, Debug, PartialEq)]
pub struct BinParams {
    /// Lower edge of the first bin.
    pub start: f64,
    /// Upper boundary of the binned range.
    pub stop: f64,
    /// Width of each bin.
    pub step: f64,
    /// Number of bins, computed as `ceil((stop - start) / step)`.
    pub n: i32,
}
154
155pub fn calculate_bin_params(
156 tx: &Bin,
157 schema: &DFSchema,
158 config: &CompilationConfig,
159) -> Result<BinParams> {
160 let extent_expr = compile(tx.extent.as_ref().unwrap(), config, Some(schema))?;
162 let extent_scalar = extent_expr.eval_to_scalar()?;
163
164 let extent = extent_scalar.to_f64x2().unwrap_or([0.0, 0.0]);
165
166 let [min_, max_] = extent;
167 if min_ > max_ {
168 return Err(VegaFusionError::specification(format!(
169 "extent[1] must be greater than extent[0]: Received {extent:?}"
170 )));
171 }
172
173 let mut span = if !approx_eq!(f64, min_, max_) {
175 max_ - min_
176 } else if !approx_eq!(f64, min_, 0.0) {
177 min_.abs()
178 } else {
179 1.0
180 };
181
182 if let Some(span_expression) = &tx.span {
184 let span_expr = compile(span_expression, config, Some(schema))?;
185 let span_scalar = span_expr.eval_to_scalar()?;
186 if let Ok(span_f64) = span_scalar.to_f64() {
187 if span_f64 > 0.0 {
188 span = span_f64;
189 }
190 }
191 }
192
193 let maxbins = compile(tx.maxbins.as_ref().unwrap(), config, Some(schema))?
194 .eval_to_scalar()?
195 .to_f64()?;
196
197 let logb = tx.base.ln();
198
199 let step = if let Some(step) = tx.step {
200 step
202 } else if !tx.steps.is_empty() {
203 let min_step_size = span / maxbins;
206 let valid_steps: Vec<_> = tx
207 .steps
208 .clone()
209 .into_iter()
210 .filter(|s| *s > min_step_size)
211 .collect();
212 *valid_steps
213 .first()
214 .unwrap_or_else(|| tx.steps.last().unwrap())
215 } else {
216 let level = (maxbins.ln() / logb).ceil();
218 let minstep = tx.minstep;
219 let mut step = minstep.max(tx.base.powf((span.ln() / logb).round() - level));
220
221 while (span / step).ceil() > maxbins {
223 step *= tx.base;
224 }
225
226 for div in &tx.divide {
228 let v = step / div;
229 if v >= minstep && span / v <= maxbins {
230 step = v
231 }
232 }
233 step
234 };
235
236 let v = step.ln();
238 let precision = if v >= 0.0 {
239 0.0
240 } else {
241 (-v / logb).floor() + 1.0
242 };
243 let eps = tx.base.powf(-precision - 1.0);
244 let (min_, max_) = if tx.nice {
245 let v = (min_ / step + eps).floor() * step;
246 let min_ = if min_ < v { v - step } else { v };
247 let max_ = (max_ / step).ceil() * step;
248 (min_, max_)
249 } else {
250 (min_, max_)
251 };
252
253 let start = min_;
255 let stop = if !approx_eq!(f64, max_, min_) {
256 max_
257 } else {
258 min_ + step
259 };
260
261 let (start, stop) = if let Some(anchor) = tx.anchor {
263 let shift = anchor - (start + step * ((anchor - start) / step).floor());
264 (start + shift, stop + shift)
265 } else {
266 (start, stop)
267 };
268
269 Ok(BinParams {
270 start,
271 stop,
272 step,
273 n: ((stop - start) / step).ceil() as i32,
274 })
275}