vegafusion_runtime/transform/bin.rs

1use crate::expression::compiler::compile;
2use crate::expression::compiler::config::CompilationConfig;
3use crate::expression::compiler::utils::ExprHelpers;
4use crate::transform::TransformTrait;
5use async_trait::async_trait;
6
7use datafusion_expr::lit;
8
9use datafusion::prelude::DataFrame;
10use datafusion_common::scalar::ScalarValue;
11use datafusion_common::utils::SingleRowListArrayBuilder;
12use datafusion_common::DFSchema;
13use datafusion_expr::when;
14use datafusion_functions::expr_fn::{abs, floor};
15use float_cmp::approx_eq;
16use std::ops::{Add, Div, Mul, Sub};
17use std::sync::Arc;
18use vegafusion_common::column::{flat_col, unescaped_col};
19use vegafusion_common::data::scalar::ScalarValueHelpers;
20use vegafusion_common::datatypes::to_numeric;
21use vegafusion_core::error::{Result, VegaFusionError};
22use vegafusion_core::proto::gen::transforms::Bin;
23use vegafusion_core::task_graph::task_value::TaskValue;
24
25#[async_trait]
26impl TransformTrait for Bin {
27    async fn eval(
28        &self,
29        sql_df: DataFrame,
30        config: &CompilationConfig,
31    ) -> Result<(DataFrame, Vec<TaskValue>)> {
32        let schema = sql_df.schema().clone();
33
34        // Compute binning solution
35        let params = calculate_bin_params(self, &schema, config)?;
36
37        let BinParams {
38            start,
39            stop,
40            step,
41            n,
42        } = params;
43        let bin_starts: Vec<f64> = (0..n).map(|i| start + step * i as f64).collect();
44        let last_stop = *bin_starts.last().unwrap() + step;
45
46        // Compute output signal value
47        let output_value = compute_output_value(self, start, stop, step);
48
49        let numeric_field = to_numeric(unescaped_col(&self.field), sql_df.schema())?;
50
51        // Add column with bin index
52        let bin_index_name = "__bin_index";
53        let bin_index =
54            floor((numeric_field.clone().sub(lit(start)).div(lit(step))).add(lit(1.0e-14)))
55                .alias(bin_index_name);
56        let sql_df = sql_df.select(vec![datafusion_expr::expr_fn::wildcard(), bin_index.into()])?;
57
58        // Add column with bin start
59        let bin_start = (flat_col(bin_index_name).mul(lit(step))).add(lit(start));
60        let bin_start_name = self.alias_0.clone().unwrap_or_else(|| "bin0".to_string());
61
62        let inf = lit(f64::INFINITY);
63        let neg_inf = lit(f64::NEG_INFINITY);
64        let eps = lit(1.0e-14);
65
66        let bin_start = when(flat_col(bin_index_name).lt(lit(0.0)), neg_inf)
67            .when(
68                abs(numeric_field.sub(lit(last_stop)))
69                    .lt(eps)
70                    .and(flat_col(bin_index_name).eq(lit(n))),
71                flat_col(bin_index_name)
72                    .sub(lit(1))
73                    .mul(lit(step))
74                    .add(lit(start)),
75            )
76            .when(flat_col(bin_index_name).gt_eq(lit(n)), inf)
77            .otherwise(bin_start)?
78            .alias(&bin_start_name);
79
80        let mut select_exprs = sql_df
81            .schema()
82            .inner()
83            .fields
84            .iter()
85            .filter_map(|field| {
86                if field.name() == &bin_start_name {
87                    None
88                } else {
89                    Some(flat_col(field.name()))
90                }
91            })
92            .collect::<Vec<_>>();
93        select_exprs.push(bin_start);
94
95        let sql_df = sql_df.select(select_exprs)?;
96
97        // Add bin end column
98        let bin_end_name = self.alias_1.clone().unwrap_or_else(|| "bin1".to_string());
99        let bin_end = (flat_col(&bin_start_name) + lit(step)).alias(&bin_end_name);
100
101        // Compute final projection that removes __bin_index column
102        let mut select_exprs = schema
103            .fields()
104            .iter()
105            .filter_map(|field| {
106                let name = field.name();
107                if name == &bin_start_name || name == &bin_end_name {
108                    None
109                } else {
110                    Some(flat_col(name))
111                }
112            })
113            .collect::<Vec<_>>();
114        select_exprs.push(flat_col(&bin_start_name));
115        select_exprs.push(bin_end);
116
117        let sql_df = sql_df.select(select_exprs)?;
118
119        Ok((sql_df, output_value.into_iter().collect()))
120    }
121}
122
123fn compute_output_value(bin_tx: &Bin, start: f64, stop: f64, step: f64) -> Option<TaskValue> {
124    let mut fname = bin_tx.field.clone();
125    fname.insert_str(0, "bin_");
126    let fields = ScalarValue::List(Arc::new(
127        SingleRowListArrayBuilder::new(
128            ScalarValue::iter_to_array(vec![ScalarValue::from(bin_tx.field.as_str())]).ok()?,
129        )
130        .with_nullable(true)
131        .build_list_array(),
132    ));
133
134    if bin_tx.signal.is_some() {
135        Some(TaskValue::Scalar(ScalarValue::from(vec![
136            ("fields", fields),
137            ("fname", ScalarValue::from(fname.as_str())),
138            ("start", ScalarValue::from(start)),
139            ("step", ScalarValue::from(step)),
140            ("stop", ScalarValue::from(stop)),
141        ])))
142    } else {
143        None
144    }
145}
146
/// Resolved parameters of a bin transform, as produced by `calculate_bin_params`.
#[derive(Clone, Debug, PartialEq)]
pub struct BinParams {
    /// Start of the first bin (after any `nice` rounding and `anchor` shift).
    pub start: f64,
    /// End of the requested bin range (after any `nice` rounding and `anchor` shift). The
    /// last bin may extend past it because `n` is rounded up.
    pub stop: f64,
    /// Width of every bin.
    pub step: f64,
    /// Number of bins: `ceil((stop - start) / step)`.
    pub n: i32,
}
154
155pub fn calculate_bin_params(
156    tx: &Bin,
157    schema: &DFSchema,
158    config: &CompilationConfig,
159) -> Result<BinParams> {
160    // Evaluate extent
161    let extent_expr = compile(tx.extent.as_ref().unwrap(), config, Some(schema))?;
162    let extent_scalar = extent_expr.eval_to_scalar()?;
163
164    let extent = extent_scalar.to_f64x2().unwrap_or([0.0, 0.0]);
165
166    let [min_, max_] = extent;
167    if min_ > max_ {
168        return Err(VegaFusionError::specification(format!(
169            "extent[1] must be greater than extent[0]: Received {extent:?}"
170        )));
171    }
172
173    // Initialize span to default value
174    let mut span = if !approx_eq!(f64, min_, max_) {
175        max_ - min_
176    } else if !approx_eq!(f64, min_, 0.0) {
177        min_.abs()
178    } else {
179        1.0
180    };
181
182    // Override span with specified value if available
183    if let Some(span_expression) = &tx.span {
184        let span_expr = compile(span_expression, config, Some(schema))?;
185        let span_scalar = span_expr.eval_to_scalar()?;
186        if let Ok(span_f64) = span_scalar.to_f64() {
187            if span_f64 > 0.0 {
188                span = span_f64;
189            }
190        }
191    }
192
193    let maxbins = compile(tx.maxbins.as_ref().unwrap(), config, Some(schema))?
194        .eval_to_scalar()?
195        .to_f64()?;
196
197    let logb = tx.base.ln();
198
199    let step = if let Some(step) = tx.step {
200        // Use provided step as-is
201        step
202    } else if !tx.steps.is_empty() {
203        // If steps is provided, limit step to one of the elements.
204        // Choose the first element of steps that will result in fewer than maxmins
205        let min_step_size = span / maxbins;
206        let valid_steps: Vec<_> = tx
207            .steps
208            .clone()
209            .into_iter()
210            .filter(|s| *s > min_step_size)
211            .collect();
212        *valid_steps
213            .first()
214            .unwrap_or_else(|| tx.steps.last().unwrap())
215    } else {
216        // Otherwise, use span to determine the step size
217        let level = (maxbins.ln() / logb).ceil();
218        let minstep = tx.minstep;
219        let mut step = minstep.max(tx.base.powf((span.ln() / logb).round() - level));
220
221        // increase step size if too many bins
222        while (span / step).ceil() > maxbins {
223            step *= tx.base;
224        }
225
226        // decrease step size if allowed
227        for div in &tx.divide {
228            let v = step / div;
229            if v >= minstep && span / v <= maxbins {
230                step = v
231            }
232        }
233        step
234    };
235
236    // Update precision of min_ and max_
237    let v = step.ln();
238    let precision = if v >= 0.0 {
239        0.0
240    } else {
241        (-v / logb).floor() + 1.0
242    };
243    let eps = tx.base.powf(-precision - 1.0);
244    let (min_, max_) = if tx.nice {
245        let v = (min_ / step + eps).floor() * step;
246        let min_ = if min_ < v { v - step } else { v };
247        let max_ = (max_ / step).ceil() * step;
248        (min_, max_)
249    } else {
250        (min_, max_)
251    };
252
253    // Compute start and stop
254    let start = min_;
255    let stop = if !approx_eq!(f64, max_, min_) {
256        max_
257    } else {
258        min_ + step
259    };
260
261    // Handle anchor
262    let (start, stop) = if let Some(anchor) = tx.anchor {
263        let shift = anchor - (start + step * ((anchor - start) / step).floor());
264        (start + shift, stop + shift)
265    } else {
266        (start, stop)
267    };
268
269    Ok(BinParams {
270        start,
271        stop,
272        step,
273        n: ((stop - start) / step).ceil() as i32,
274    })
275}