vegafusion_runtime/transform/
extent.rs1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4
5use crate::data::util::DataFrameUtils;
6use datafusion::arrow::array::RecordBatch;
7use datafusion::prelude::DataFrame;
8use datafusion_common::utils::SingleRowListArrayBuilder;
9use datafusion_common::{DFSchema, ScalarValue};
10use datafusion_expr::Expr;
11use datafusion_functions_aggregate::expr_fn::{max, min};
12use std::sync::Arc;
13use vegafusion_common::column::unescaped_col;
14use vegafusion_common::datatypes::to_numeric;
15use vegafusion_common::error::{Result, ResultWithContext};
16use vegafusion_core::proto::gen::transforms::Extent;
17use vegafusion_core::task_graph::task_value::TaskValue;
18
19#[async_trait]
20impl TransformTrait for Extent {
21 async fn eval(
22 &self,
23 sql_df: DataFrame,
24 _config: &CompilationConfig,
25 ) -> Result<(DataFrame, Vec<TaskValue>)> {
26 let output_values = if self.signal.is_some() {
27 let (min_expr, max_expr) = min_max_exprs(self.field.as_str(), sql_df.schema())?;
28
29 let extent_df = sql_df
30 .clone()
31 .aggregate(Vec::new(), vec![min_expr, max_expr])?;
32
33 let result_batch = extent_df.collect_flat().await?;
35 let extent_list = extract_extent_list(&result_batch)?;
36 vec![extent_list]
37 } else {
38 Vec::new()
39 };
40
41 Ok((sql_df, output_values))
42 }
43}
44
45fn min_max_exprs(field: &str, schema: &DFSchema) -> Result<(Expr, Expr)> {
46 let field_col = unescaped_col(field);
47 let min_expr = min(to_numeric(field_col.clone(), schema)?).alias("__min_val");
48 let max_expr = max(to_numeric(field_col, schema)?).alias("__max_val");
49 Ok((min_expr, max_expr))
50}
51
52fn extract_extent_list(batch: &RecordBatch) -> Result<TaskValue> {
53 let min_val_array = batch
54 .column_by_name("__min_val")
55 .with_context(|| "No column named __min_val".to_string())?;
56 let max_val_array = batch
57 .column_by_name("__max_val")
58 .with_context(|| "No column named __max_val".to_string())?;
59
60 let min_val_scalar = ScalarValue::try_from_array(min_val_array, 0).unwrap();
61 let max_val_scalar = ScalarValue::try_from_array(max_val_array, 0).unwrap();
62
63 let extent_list = TaskValue::Scalar(ScalarValue::List(Arc::new(
65 SingleRowListArrayBuilder::new(ScalarValue::iter_to_array(vec![
66 min_val_scalar,
67 max_val_scalar,
68 ])?)
69 .with_nullable(true)
70 .build_list_array(),
71 )));
72 Ok(extent_list)
73}