vegafusion_runtime/transform/
pivot.rs1use crate::data::util::DataFrameUtils;
2use crate::expression::compiler::config::CompilationConfig;
3use crate::transform::aggregate::make_agg_expr_for_col_expr;
4use crate::transform::TransformTrait;
5use async_trait::async_trait;
6use datafusion::prelude::DataFrame;
7use datafusion_expr::{lit, when};
8use datafusion_functions_aggregate::expr_fn::min;
9use vegafusion_common::arrow::array::StringArray;
10use vegafusion_common::arrow::datatypes::DataType;
11use vegafusion_common::column::{flat_col, unescaped_col};
12use vegafusion_common::data::scalar::ScalarValue;
13use vegafusion_common::data::ORDER_COL;
14use vegafusion_common::datatypes::{cast_to, data_type, is_string_datatype, to_numeric};
15use vegafusion_common::error::{Result, ResultWithContext, VegaFusionError};
16use vegafusion_common::escape::unescape_field;
17use vegafusion_core::proto::gen::transforms::{AggregateOp, Pivot};
18use vegafusion_core::task_graph::task_value::TaskValue;
19
/// Stand-in string written into the pivot column wherever the pivot value is
/// null, so that null participates in grouping and sorting like any other
/// label. It is mapped back to [`NULL_NAME`] when output columns are named.
/// NOTE(review): the leading `!!!` presumably makes it sort ahead of ordinary
/// labels in the ascending sort applied before the limit; confirm.
const NULL_PLACEHOLDER_NAME: &str = "!!!null";

/// Name given to the output column that corresponds to null pivot values.
const NULL_NAME: &str = "null";
26
27#[async_trait]
28impl TransformTrait for Pivot {
29 async fn eval(
30 &self,
31 dataframe: DataFrame,
32 _config: &CompilationConfig,
33 ) -> Result<(DataFrame, Vec<TaskValue>)> {
34 let pivot_dtype = data_type(&unescaped_col(&self.field), dataframe.schema())?;
36 let dataframe = if matches!(pivot_dtype, DataType::Boolean) {
37 let select_exprs: Vec<_> = dataframe
39 .schema()
40 .inner()
41 .fields
42 .iter()
43 .map(|field| {
44 if field.name() == &unescape_field(&self.field) {
45 Ok(when(unescaped_col(&self.field).eq(lit(true)), lit("true"))
46 .when(
47 unescaped_col(&self.field).is_null(),
48 lit(NULL_PLACEHOLDER_NAME),
49 )
50 .otherwise(lit("false"))
51 .with_context(|| "Failed to construct Case expression")?
52 .alias(&self.field))
53 } else {
54 Ok(flat_col(field.name()))
55 }
56 })
57 .collect::<Result<Vec<_>>>()?;
58 dataframe.select(select_exprs)?
59 } else if !is_string_datatype(&pivot_dtype) {
60 let select_exprs: Vec<_> = dataframe
62 .schema()
63 .inner()
64 .fields
65 .iter()
66 .map(|field| {
67 if field.name() == &unescape_field(&self.field) {
68 Ok(when(
69 unescaped_col(&self.field).is_null(),
70 lit(NULL_PLACEHOLDER_NAME),
71 )
72 .otherwise(cast_to(
73 unescaped_col(&self.field),
74 &DataType::Utf8,
75 dataframe.schema(),
76 )?)?
77 .alias(&self.field))
78 } else {
79 Ok(flat_col(field.name()))
80 }
81 })
82 .collect::<Result<Vec<_>>>()?;
83 dataframe.select(select_exprs)?
84 } else {
85 let select_exprs: Vec<_> = dataframe
87 .schema()
88 .inner()
89 .fields
90 .iter()
91 .map(|field| {
92 if field.name() == &unescape_field(&self.field) {
93 let field_col = unescaped_col(&self.field);
94 Ok(
95 when(field_col.clone().is_null(), lit(NULL_PLACEHOLDER_NAME))
96 .when(field_col.clone().eq(lit("")), lit(" "))
97 .otherwise(field_col)?
98 .alias(&self.field),
99 )
100 } else {
101 Ok(flat_col(field.name()))
102 }
103 })
104 .collect::<Result<Vec<_>>>()?;
105 dataframe.select(select_exprs)?
106 };
107
108 pivot_case(self, dataframe).await
109 }
110}
111
112async fn extract_sorted_pivot_values(tx: &Pivot, dataframe: DataFrame) -> Result<Vec<String>> {
113 let agg_query = dataframe.aggregate_mixed(vec![unescaped_col(&tx.field)], vec![])?;
114
115 let limit = match tx.limit {
116 None | Some(0) => None,
117 Some(i) => Some(i as usize),
118 };
119
120 let sorted_query = agg_query
121 .sort(vec![unescaped_col(&tx.field).sort(true, false)])?
122 .limit(0, limit)?;
123
124 let pivot_batch = sorted_query.collect_flat().await?;
125 let pivot_array = pivot_batch
126 .column_by_name(&tx.field)
127 .with_context(|| format!("No column named {}", tx.field))?;
128 let pivot_array = pivot_array
129 .as_any()
130 .downcast_ref::<StringArray>()
131 .with_context(|| "Failed to downcast pivot column to String")?;
132 let pivot_vec: Vec<_> = pivot_array
133 .iter()
134 .filter_map(|val| val.map(|s| s.to_string()))
135 .collect();
136 Ok(pivot_vec)
137}
138
139async fn pivot_case(tx: &Pivot, dataframe: DataFrame) -> Result<(DataFrame, Vec<TaskValue>)> {
140 let pivot_vec = extract_sorted_pivot_values(tx, dataframe.clone()).await?;
141
142 if pivot_vec.is_empty() {
143 return Err(VegaFusionError::internal("Unexpected empty pivot dataset"));
144 }
145
146 let schema = dataframe.schema();
147
148 let agg_op: AggregateOp = tx
150 .op
151 .map(|op_code| AggregateOp::try_from(op_code).unwrap())
152 .unwrap_or(AggregateOp::Sum);
153
154 let mut agg_exprs: Vec<_> = Vec::new();
156
157 for pivot_val in pivot_vec.iter() {
158 let predicate_expr = unescaped_col(&tx.field).eq(lit(pivot_val.as_str()));
159 let value_expr = to_numeric(unescaped_col(tx.value.as_str()), schema)?;
160 let agg_col = when(predicate_expr, value_expr).otherwise(lit(ScalarValue::Null))?;
161
162 let agg_expr = make_agg_expr_for_col_expr(agg_col, &agg_op, schema)?;
163
164 let col_name = if pivot_val == NULL_PLACEHOLDER_NAME {
166 NULL_NAME
167 } else {
168 pivot_val.as_str()
169 };
170 let agg_expr = agg_expr.alias(col_name);
171
172 agg_exprs.push(agg_expr);
173 }
174
175 agg_exprs.insert(0, min(flat_col(ORDER_COL)).alias(ORDER_COL));
177
178 let group_expr: Vec<_> = tx.groupby.iter().map(|c| unescaped_col(c)).collect();
180
181 let pivoted = dataframe.aggregate_mixed(group_expr, agg_exprs)?;
182 Ok((pivoted, Default::default()))
183}