vegafusion_runtime/transform/
extent.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4
5use crate::data::util::DataFrameUtils;
6use datafusion::arrow::array::RecordBatch;
7use datafusion::prelude::DataFrame;
8use datafusion_common::utils::SingleRowListArrayBuilder;
9use datafusion_common::{DFSchema, ScalarValue};
10use datafusion_expr::Expr;
11use datafusion_functions_aggregate::expr_fn::{max, min};
12use std::sync::Arc;
13use vegafusion_common::column::unescaped_col;
14use vegafusion_common::datatypes::to_numeric;
15use vegafusion_common::error::{Result, ResultWithContext};
16use vegafusion_core::proto::gen::transforms::Extent;
17use vegafusion_core::task_graph::task_value::TaskValue;
18
19#[async_trait]
20impl TransformTrait for Extent {
21    async fn eval(
22        &self,
23        sql_df: DataFrame,
24        _config: &CompilationConfig,
25    ) -> Result<(DataFrame, Vec<TaskValue>)> {
26        let output_values = if self.signal.is_some() {
27            let (min_expr, max_expr) = min_max_exprs(self.field.as_str(), sql_df.schema())?;
28
29            let extent_df = sql_df
30                .clone()
31                .aggregate(Vec::new(), vec![min_expr, max_expr])?;
32
33            // Eval to single row dataframe and extract scalar values
34            let result_batch = extent_df.collect_flat().await?;
35            let extent_list = extract_extent_list(&result_batch)?;
36            vec![extent_list]
37        } else {
38            Vec::new()
39        };
40
41        Ok((sql_df, output_values))
42    }
43}
44
45fn min_max_exprs(field: &str, schema: &DFSchema) -> Result<(Expr, Expr)> {
46    let field_col = unescaped_col(field);
47    let min_expr = min(to_numeric(field_col.clone(), schema)?).alias("__min_val");
48    let max_expr = max(to_numeric(field_col, schema)?).alias("__max_val");
49    Ok((min_expr, max_expr))
50}
51
52fn extract_extent_list(batch: &RecordBatch) -> Result<TaskValue> {
53    let min_val_array = batch
54        .column_by_name("__min_val")
55        .with_context(|| "No column named __min_val".to_string())?;
56    let max_val_array = batch
57        .column_by_name("__max_val")
58        .with_context(|| "No column named __max_val".to_string())?;
59
60    let min_val_scalar = ScalarValue::try_from_array(min_val_array, 0).unwrap();
61    let max_val_scalar = ScalarValue::try_from_array(max_val_array, 0).unwrap();
62
63    // Build two-element list of the extents
64    let extent_list = TaskValue::Scalar(ScalarValue::List(Arc::new(
65        SingleRowListArrayBuilder::new(ScalarValue::iter_to_array(vec![
66            min_val_scalar,
67            max_val_scalar,
68        ])?)
69        .with_nullable(true)
70        .build_list_array(),
71    )));
72    Ok(extent_list)
73}