vegafusion_runtime/transform/
impute.rs

1use crate::expression::compiler::config::CompilationConfig;
2
3use crate::data::util::DataFrameUtils;
4use crate::expression::compiler::utils::ExprHelpers;
5use crate::transform::TransformTrait;
6use async_trait::async_trait;
7use datafusion::prelude::DataFrame;
8use datafusion_common::{JoinType, ScalarValue};
9use datafusion_expr::{
10    expr, expr::WindowFunctionParams, lit, when, Expr, SortExpr, WindowFrame,
11    WindowFunctionDefinition,
12};
// Null values are filled with `when(..).otherwise(..)` expressions, so `coalesce` is not imported
14use datafusion_functions_aggregate::expr_fn::min;
15use datafusion_functions_window::row_number::RowNumber;
16use itertools::Itertools;
17use sqlparser::ast::NullTreatment;
18use std::sync::Arc;
19use vegafusion_common::column::{flat_col, relation_col};
20use vegafusion_common::data::scalar::ScalarValueHelpers;
21use vegafusion_common::data::ORDER_COL;
22use vegafusion_common::error::{Result, ResultWithContext};
23use vegafusion_common::escape::unescape_field;
24use vegafusion_core::proto::gen::transforms::Impute;
25use vegafusion_core::task_graph::task_value::TaskValue;
26
27#[async_trait]
28impl TransformTrait for Impute {
29    async fn eval(
30        &self,
31        dataframe: DataFrame,
32        _config: &CompilationConfig,
33    ) -> Result<(DataFrame, Vec<TaskValue>)> {
34        // Create ScalarValue used to fill in null values
35        let json_value: serde_json::Value = serde_json::from_str(
36            &self
37                .value_json
38                .clone()
39                .unwrap_or_else(|| "null".to_string()),
40        )?;
41
42        // JSON numbers are always interpreted as floats, but if the value is an integer we'd
43        // like the fill value to be an integer as well to avoid converting an integer input
44        // column to floats
45        let value = if json_value.is_null() {
46            ScalarValue::Float64(None)
47        } else if json_value.is_i64() {
48            ScalarValue::from(json_value.as_i64().unwrap())
49        } else if json_value.is_f64() && json_value.as_f64().unwrap().fract() == 0.0 {
50            ScalarValue::from(json_value.as_f64().unwrap() as i64)
51        } else {
52            ScalarValue::from_json(&json_value)?
53        };
54
55        // Take unique groupby fields (in case there are duplicates)
56        let groupby = self
57            .groupby
58            .clone()
59            .into_iter()
60            .unique()
61            .collect::<Vec<_>>();
62
63        // Unescape field, key, and groupby fields
64        let field = unescape_field(&self.field);
65        let key = unescape_field(&self.key);
66        let groupby: Vec<_> = groupby.iter().map(|f| unescape_field(f)).collect();
67
68        let schema = dataframe.schema();
69        let (_, field_field) = schema
70            .inner()
71            .column_with_name(&field)
72            .with_context(|| format!("No field named {}", field))?;
73        let field_type = field_field.data_type();
74
75        if groupby.is_empty() {
76            // Value replacement for field with no group_by fields specified is equivalent to replacing
77            // null values of that column with the fill value
78            let select_columns = schema
79                .fields()
80                .iter()
81                .map(|f| {
82                    let col_name = f.name();
83                    Ok(if col_name == &field {
84                        let casted_value = lit(value.clone()).try_cast_to(field_type, schema)?;
85                        when(flat_col(&field).is_null(), casted_value)
86                            .otherwise(flat_col(&field))?
87                            .alias(col_name)
88                    } else {
89                        flat_col(col_name)
90                    })
91                })
92                .collect::<Result<Vec<_>>>()?;
93
94            Ok((dataframe.select(select_columns)?, Vec::new()))
95        } else {
96            // First step is to build up a new DataFrame that contains the all possible combinations
97
98            // Build some internal columns for intermediate ordering
99            let order_col = flat_col(ORDER_COL);
100            let order_key = format!("{ORDER_COL}_key");
101            let order_key_col = flat_col(&order_key);
102            let order_group = format!("{ORDER_COL}_groups");
103            let order_group_col = flat_col(&order_group);
104
105            // Create DataFrame with unique key values, and an internal ordering column
106            let key_col = flat_col(&key);
107            let key_df = dataframe
108                .clone()
109                .filter(key_col.clone().is_not_null())?
110                .aggregate_mixed(
111                    vec![key_col.clone()],
112                    vec![min(order_col.clone()).alias(&order_key)],
113                )?;
114
115            // Create DataFrame with unique combinations of group_by values, with an
116            // internal ordering col
117            let group_cols = groupby.iter().map(|c| flat_col(c)).collect::<Vec<_>>();
118
119            let groups_df = dataframe
120                .clone()
121                .aggregate_mixed(group_cols, vec![min(order_col.clone()).alias(&order_group)])?;
122
123            // Build join conditions
124            let mut on_exprs = groupby
125                .iter()
126                .map(|c| relation_col(c, "lhs").eq(relation_col(c, "rhs")))
127                .collect::<Vec<_>>();
128            on_exprs.push(relation_col(&key, "lhs").eq(relation_col(&key, "rhs")));
129
130            // Perform cross join by using a dummy always-true condition
131            // This is needed because empty join conditions are not allowed
132            let pre_ordered_df = key_df
133                .join_on(groups_df, JoinType::Inner, vec![lit(true)])?
134                .alias("lhs")?
135                .join_on(dataframe.clone().alias("rhs")?, JoinType::Left, on_exprs)?;
136
137            // Build final selection that fills in missing values and adds ordering column
138            let mut final_selections = Vec::new();
139            for field_index in 0..schema.fields().len() {
140                let (_, f) = schema.qualified_field(field_index);
141
142                if f.name().starts_with(ORDER_COL) {
143                    // Skip all order cols
144                    continue;
145                } else if f.name() == &field {
146                    // Coalesce to fill in null values in field
147                    let casted_value = lit(value.clone()).try_cast_to(field_type, schema)?;
148                    final_selections.push(
149                        when(flat_col(&field).is_null(), casted_value)
150                            .otherwise(flat_col(&field))?
151                            .alias(f.name()),
152                    );
153                } else {
154                    // Keep other columns
155                    if f.name() == &key || groupby.contains(f.name()) {
156                        // Pull key and groupby columns from the "lhs" table (which won't have nulls
157                        // introduced by the left join)
158                        final_selections.push(relation_col(f.name(), "lhs"));
159                    } else {
160                        // Pull all other columns from the rhs table
161                        final_selections.push(relation_col(f.name(), "rhs"));
162                    }
163                }
164            }
165
166            let final_order_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
167                fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
168                params: WindowFunctionParams {
169                    args: vec![],
170                    partition_by: vec![],
171                    order_by: vec![
172                        // Sort first by the original row order, pushing imputed rows to the end
173                        SortExpr::new(order_col.clone(), true, false),
174                        // Sort imputed rows by first row that resides group
175                        // then by first row that matches a key
176                        SortExpr::new(order_group_col, true, true),
177                        SortExpr::new(order_key_col, true, true),
178                    ],
179                    window_frame: WindowFrame::new(Some(true)),
180                    null_treatment: Some(NullTreatment::RespectNulls),
181                },
182            }))
183            .alias(ORDER_COL);
184            final_selections.push(final_order_expr);
185
186            Ok((pre_ordered_df.select(final_selections)?, Default::default()))
187        }
188    }
189}