vegafusion_runtime/transform/
pivot.rs

1use crate::data::util::DataFrameUtils;
2use crate::expression::compiler::config::CompilationConfig;
3use crate::transform::aggregate::make_agg_expr_for_col_expr;
4use crate::transform::TransformTrait;
5use async_trait::async_trait;
6use datafusion::prelude::DataFrame;
7use datafusion_expr::{lit, when};
8use datafusion_functions_aggregate::expr_fn::min;
9use vegafusion_common::arrow::array::StringArray;
10use vegafusion_common::arrow::datatypes::DataType;
11use vegafusion_common::column::{flat_col, unescaped_col};
12use vegafusion_common::data::scalar::ScalarValue;
13use vegafusion_common::data::ORDER_COL;
14use vegafusion_common::datatypes::{cast_to, data_type, is_string_datatype, to_numeric};
15use vegafusion_common::error::{Result, ResultWithContext, VegaFusionError};
16use vegafusion_common::escape::unescape_field;
17use vegafusion_core::proto::gen::transforms::{AggregateOp, Pivot};
18use vegafusion_core::task_graph::task_value::TaskValue;
19
/// Stand-in string written into the pivot column wherever the value is null.
///
/// Pivot values are sorted ascending before the optional limit is applied, and Vega always
/// places null first in that ordering; this placeholder is used during sorting to match that.
const NULL_PLACEHOLDER_NAME: &str = "!!!null";

/// Name given to the output column that corresponds to null pivot values (the placeholder
/// above is translated to this name when the pivoted columns are aliased).
const NULL_NAME: &str = "null";
26
27#[async_trait]
28impl TransformTrait for Pivot {
29    async fn eval(
30        &self,
31        dataframe: DataFrame,
32        _config: &CompilationConfig,
33    ) -> Result<(DataFrame, Vec<TaskValue>)> {
34        // Make sure the pivot column is a string
35        let pivot_dtype = data_type(&unescaped_col(&self.field), dataframe.schema())?;
36        let dataframe = if matches!(pivot_dtype, DataType::Boolean) {
37            // Boolean column type. For consistency with vega, replace 0 with "false" and 1 with "true"
38            let select_exprs: Vec<_> = dataframe
39                .schema()
40                .inner()
41                .fields
42                .iter()
43                .map(|field| {
44                    if field.name() == &unescape_field(&self.field) {
45                        Ok(when(unescaped_col(&self.field).eq(lit(true)), lit("true"))
46                            .when(
47                                unescaped_col(&self.field).is_null(),
48                                lit(NULL_PLACEHOLDER_NAME),
49                            )
50                            .otherwise(lit("false"))
51                            .with_context(|| "Failed to construct Case expression")?
52                            .alias(&self.field))
53                    } else {
54                        Ok(flat_col(field.name()))
55                    }
56                })
57                .collect::<Result<Vec<_>>>()?;
58            dataframe.select(select_exprs)?
59        } else if !is_string_datatype(&pivot_dtype) {
60            // Column type is not string, so cast values to strings
61            let select_exprs: Vec<_> = dataframe
62                .schema()
63                .inner()
64                .fields
65                .iter()
66                .map(|field| {
67                    if field.name() == &unescape_field(&self.field) {
68                        Ok(when(
69                            unescaped_col(&self.field).is_null(),
70                            lit(NULL_PLACEHOLDER_NAME),
71                        )
72                        .otherwise(cast_to(
73                            unescaped_col(&self.field),
74                            &DataType::Utf8,
75                            dataframe.schema(),
76                        )?)?
77                        .alias(&self.field))
78                    } else {
79                        Ok(flat_col(field.name()))
80                    }
81                })
82                .collect::<Result<Vec<_>>>()?;
83            dataframe.select(select_exprs)?
84        } else {
85            // Column type is string, just replace NULL with "null"
86            let select_exprs: Vec<_> = dataframe
87                .schema()
88                .inner()
89                .fields
90                .iter()
91                .map(|field| {
92                    if field.name() == &unescape_field(&self.field) {
93                        let field_col = unescaped_col(&self.field);
94                        Ok(
95                            when(field_col.clone().is_null(), lit(NULL_PLACEHOLDER_NAME))
96                                .when(field_col.clone().eq(lit("")), lit(" "))
97                                .otherwise(field_col)?
98                                .alias(&self.field),
99                        )
100                    } else {
101                        Ok(flat_col(field.name()))
102                    }
103                })
104                .collect::<Result<Vec<_>>>()?;
105            dataframe.select(select_exprs)?
106        };
107
108        pivot_case(self, dataframe).await
109    }
110}
111
112async fn extract_sorted_pivot_values(tx: &Pivot, dataframe: DataFrame) -> Result<Vec<String>> {
113    let agg_query = dataframe.aggregate_mixed(vec![unescaped_col(&tx.field)], vec![])?;
114
115    let limit = match tx.limit {
116        None | Some(0) => None,
117        Some(i) => Some(i as usize),
118    };
119
120    let sorted_query = agg_query
121        .sort(vec![unescaped_col(&tx.field).sort(true, false)])?
122        .limit(0, limit)?;
123
124    let pivot_batch = sorted_query.collect_flat().await?;
125    let pivot_array = pivot_batch
126        .column_by_name(&tx.field)
127        .with_context(|| format!("No column named {}", tx.field))?;
128    let pivot_array = pivot_array
129        .as_any()
130        .downcast_ref::<StringArray>()
131        .with_context(|| "Failed to downcast pivot column to String")?;
132    let pivot_vec: Vec<_> = pivot_array
133        .iter()
134        .filter_map(|val| val.map(|s| s.to_string()))
135        .collect();
136    Ok(pivot_vec)
137}
138
139async fn pivot_case(tx: &Pivot, dataframe: DataFrame) -> Result<(DataFrame, Vec<TaskValue>)> {
140    let pivot_vec = extract_sorted_pivot_values(tx, dataframe.clone()).await?;
141
142    if pivot_vec.is_empty() {
143        return Err(VegaFusionError::internal("Unexpected empty pivot dataset"));
144    }
145
146    let schema = dataframe.schema();
147
148    // Process aggregate operation
149    let agg_op: AggregateOp = tx
150        .op
151        .map(|op_code| AggregateOp::try_from(op_code).unwrap())
152        .unwrap_or(AggregateOp::Sum);
153
154    // Build vector of aggregates
155    let mut agg_exprs: Vec<_> = Vec::new();
156
157    for pivot_val in pivot_vec.iter() {
158        let predicate_expr = unescaped_col(&tx.field).eq(lit(pivot_val.as_str()));
159        let value_expr = to_numeric(unescaped_col(tx.value.as_str()), schema)?;
160        let agg_col = when(predicate_expr, value_expr).otherwise(lit(ScalarValue::Null))?;
161
162        let agg_expr = make_agg_expr_for_col_expr(agg_col, &agg_op, schema)?;
163
164        // Compute pivot column name, replacing null placeholder with "null"
165        let col_name = if pivot_val == NULL_PLACEHOLDER_NAME {
166            NULL_NAME
167        } else {
168            pivot_val.as_str()
169        };
170        let agg_expr = agg_expr.alias(col_name);
171
172        agg_exprs.push(agg_expr);
173    }
174
175    // Insert ordering aggregate
176    agg_exprs.insert(0, min(flat_col(ORDER_COL)).alias(ORDER_COL));
177
178    // Build vector of groupby expressions
179    let group_expr: Vec<_> = tx.groupby.iter().map(|c| unescaped_col(c)).collect();
180
181    let pivoted = dataframe.aggregate_mixed(group_expr, agg_exprs)?;
182    Ok((pivoted, Default::default()))
183}