vortex_datafusion/vendor/
schema_rewriter.rs

1// SPDX-FileCopyrightText: 2016-2025 Copyright The Apache Software Foundation
2// SPDX-FileCopyrightText: 2025 Copyright the Vortex contributors
3// SPDX-License-Identifier: Apache-2.0
4// SPDX-FileComment: Derived from upstream file datafusion/physical-expr-adapter/src/schema_rewriter.rs at commit e571b49 at https://github.com/apache/datafusion
5// SPDX-FileNotice: https://github.com/apache/datafusion/blob/e571b49e0983892597a8f92e5d1502b17a15b180/NOTICE.txt
6
7#![allow(missing_docs)]
8
9//! Physical expression schema rewriting utilities
10//!
11//! NOTE(aduffy): this is vendored until DF 52 is released, at which point this should
12//!     all be deleted.
13
14use std::sync::Arc;
15
16use datafusion_common::Result;
17use datafusion_common::ScalarValue;
18use datafusion_common::arrow::compute::can_cast_types;
19use datafusion_common::arrow::datatypes::DataType;
20use datafusion_common::arrow::datatypes::FieldRef;
21use datafusion_common::arrow::datatypes::Schema;
22use datafusion_common::arrow::datatypes::SchemaRef;
23use datafusion_common::exec_err;
24use datafusion_common::nested_struct::validate_struct_compatibility;
25use datafusion_common::tree_node::Transformed;
26use datafusion_common::tree_node::TransformedResult;
27use datafusion_common::tree_node::TreeNode;
28use datafusion_functions::core::getfield::GetFieldFunc;
29use datafusion_physical_expr::ScalarFunctionExpr;
30use datafusion_physical_expr::expressions::CastColumnExpr;
31use datafusion_physical_expr::expressions::Column;
32use datafusion_physical_expr::expressions::{self};
33use datafusion_physical_expr_adapter::PhysicalExprAdapter;
34use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36
/// Factory producing [`DF52PhysicalExprAdapter`] instances for a given pair of
/// logical and physical file schemas.
///
/// Vendored stand-in for the default adapter factory that ships with
/// DataFusion 52; see the module docs for the removal plan.
//
// `Default` is derived so callers can construct the (stateless) factory
// without naming the unit value explicitly, e.g. in builder-style APIs.
#[derive(Debug, Clone, Default)]
pub struct DF52PhysicalExprAdapterFactory;
39
40impl PhysicalExprAdapterFactory for DF52PhysicalExprAdapterFactory {
41    fn create(
42        &self,
43        logical_file_schema: SchemaRef,
44        physical_file_schema: SchemaRef,
45    ) -> Arc<dyn PhysicalExprAdapter> {
46        Arc::new(DF52PhysicalExprAdapter {
47            logical_file_schema,
48            physical_file_schema,
49            partition_values: Vec::new(),
50        })
51    }
52}
53
/// Vendored copy of DataFusion's default physical expression adapter.
///
/// Rewrites physical expressions written against a table's logical schema so
/// they can be evaluated against the (possibly different) physical schema of
/// a specific file, inserting casts, null fills, and partition-value literals
/// as needed (see the `PhysicalExprAdapter::rewrite` impl below).
#[derive(Debug, Clone)]
pub struct DF52PhysicalExprAdapter {
    // Schema the query/expressions were planned against.
    logical_file_schema: SchemaRef,
    // Schema actually present in the file being scanned.
    physical_file_schema: SchemaRef,
    // (field, value) pairs substituted for partition-column references;
    // empty until `with_partition_values` is called.
    partition_values: Vec<(FieldRef, ScalarValue)>,
}
60
61impl DF52PhysicalExprAdapter {
62    /// Create a new instance of the default physical expression adapter.
63    ///
64    /// This adapter rewrites expressions to match the physical schema of the file being scanned,
65    /// handling type mismatches and missing columns by filling them with default values.
66    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
67        Self {
68            logical_file_schema,
69            physical_file_schema,
70            partition_values: Vec::new(),
71        }
72    }
73}
74
75impl PhysicalExprAdapter for DF52PhysicalExprAdapter {
76    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
77        let rewriter = DefaultPhysicalExprAdapterRewriter {
78            logical_file_schema: &self.logical_file_schema,
79            physical_file_schema: &self.physical_file_schema,
80            partition_fields: &self.partition_values,
81        };
82        expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
83            .data()
84    }
85
86    fn with_partition_values(
87        &self,
88        partition_values: Vec<(FieldRef, ScalarValue)>,
89    ) -> Arc<dyn PhysicalExprAdapter> {
90        Arc::new(DF52PhysicalExprAdapter {
91            partition_values,
92            ..self.clone()
93        })
94    }
95}
96
/// Borrowed view over the adapter's state, used while walking an expression
/// tree during a single `rewrite` call.
struct DefaultPhysicalExprAdapterRewriter<'a> {
    // Schema the expressions reference (what the query was planned against).
    logical_file_schema: &'a Schema,
    // Schema of the file actually being scanned.
    physical_file_schema: &'a Schema,
    // (field, value) pairs for partition columns, if any.
    partition_fields: &'a [(FieldRef, ScalarValue)],
}
102
impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    /// Rewrite a single expression node (invoked for every node by
    /// `TreeNode::transform`).
    ///
    /// Two cases are handled: a `get_field` access on a struct column whose
    /// field is missing from the physical schema (replaced with a typed null
    /// literal), and a bare column reference (re-indexed, cast, null-filled,
    /// or replaced with a partition value via `rewrite_column`). All other
    /// nodes are passed through unchanged.
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
            return Ok(Transformed::yes(transformed));
        }

        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            // Clone the Arc here because `column` borrows from `expr`; passing
            // `expr` by value while the borrow is live would not compile.
            return self.rewrite_column(Arc::clone(&expr), column);
        }

        Ok(Transformed::no(expr))
    }

    /// Attempt to rewrite struct field access expressions to return null if the field does not exist in the physical schema.
    /// Note that this does *not* handle nested struct fields, only top-level struct field access.
    /// See <https://github.com/apache/datafusion/issues/17114> for more details.
    ///
    /// Returns `Ok(None)` whenever the expression does not match the exact
    /// shape `get_field(<column>, <string literal>)` on a struct column, or
    /// the field actually exists in the physical file — in which case the
    /// caller leaves the expression alone.
    fn try_rewrite_struct_field_access(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        // Only `get_field` scalar-function calls are candidates.
        let get_field_expr =
            match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
                Some(expr) => expr,
                None => return Ok(None),
            };

        // arg 0: the struct-valued source expression.
        let source_expr = match get_field_expr.args().first() {
            Some(expr) => expr,
            None => return Ok(None),
        };

        // arg 1: the field name, which must be a string literal.
        let field_name_expr = match get_field_expr.args().get(1) {
            Some(expr) => expr,
            None => return Ok(None),
        };

        let lit = match field_name_expr
            .as_any()
            .downcast_ref::<expressions::Literal>()
        {
            Some(lit) => lit,
            None => return Ok(None),
        };

        let field_name = match lit.value().try_as_str().flatten() {
            Some(name) => name,
            None => return Ok(None),
        };

        // Only top-level `column.field` accesses are handled (no nesting).
        let column = match source_expr.as_any().downcast_ref::<Column>() {
            Some(column) => column,
            None => return Ok(None),
        };

        let physical_field = match self.physical_file_schema.field_with_name(column.name()) {
            Ok(field) => field,
            Err(_) => return Ok(None),
        };

        let physical_struct_fields = match physical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        // If the field exists in the physical file, there is nothing to fix up.
        if physical_struct_fields
            .iter()
            .any(|f| f.name() == field_name)
        {
            return Ok(None);
        }

        // The field is missing from the file; look up its logical type so the
        // replacement null literal carries the type the plan expects.
        let logical_field = match self.logical_file_schema.field_with_name(column.name()) {
            Ok(field) => field,
            Err(_) => return Ok(None),
        };

        let logical_struct_fields = match logical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        let logical_struct_field = match logical_struct_fields
            .iter()
            .find(|f| f.name() == field_name)
        {
            Some(field) => field,
            None => return Ok(None),
        };

        // Replace the whole `get_field` call with a typed null literal.
        let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
        Ok(Some(expressions::lit(null_value)))
    }

    /// Rewrite a column reference to match the physical file schema.
    ///
    /// Outcomes, in the order they are checked:
    /// - partition column → replaced with its literal partition value
    /// - missing nullable column → replaced with a typed null literal
    /// - missing non-nullable column → execution error
    /// - index and type both match → expression returned unchanged
    /// - otherwise → a re-indexed `Column`, wrapped in a `CastColumnExpr`
    ///   when the physical and logical data types differ
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        // Get the logical field for this column if it exists in the logical schema
        let logical_field = match self.logical_file_schema.field_with_name(column.name()) {
            Ok(field) => field,
            Err(e) => {
                // If the column is a partition field, we can use the partition value
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
                // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema.
                // For example, a pre-computed column that is kept only in the physical schema.
                // If the column exists in the physical schema, we can still use it.
                if let Ok(physical_field) = self.physical_file_schema.field_with_name(column.name())
                {
                    // If the column exists in the physical schema, we can use it in place of the logical column.
                    // This is nice to users because if they do a rewrite that results in something like `physical_int32_col = 123u64`
                    // we'll at least handle the casts for them.
                    physical_field
                } else {
                    // A completely unknown column that doesn't exist in either schema!
                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
                    // for us to return a handleable error than to panic / do something unexpected.
                    return Err(e.into());
                }
            }
        };

        // Check if the column exists in the physical schema
        let physical_column_index = match self.physical_file_schema.index_of(column.name()) {
            Ok(index) => index,
            Err(_) => {
                if !logical_field.is_nullable() {
                    return exec_err!(
                        "Non-nullable column '{}' is missing from the physical schema",
                        column.name()
                    );
                }
                // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
                // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
                // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
                // See https://github.com/apache/datafusion/issues/16527
                let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
                return Ok(Transformed::yes(expressions::lit(null_value)));
            }
        };
        let physical_field = self.physical_file_schema.field(physical_column_index);

        // Decide whether the existing `Column` can be reused or needs to be
        // re-pointed at the physical schema's index for this name.
        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // If the column index matches and the data types match, we can use the column as is
            (true, true) => return Ok(Transformed::no(expr)),
            // If the indexes or data types do not match, we need to create a new column expression
            (true, _) => column.clone(),
            (false, _) => Column::new_with_schema(logical_field.name(), self.physical_file_schema)?,
        };

        if logical_field.data_type() == physical_field.data_type() {
            // If the data types match, we can use the column as is
            return Ok(Transformed::yes(Arc::new(column)));
        }

        // We need to cast the column to the logical data type
        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
        // since that's much cheaper to evalaute.
        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
        match (physical_field.data_type(), logical_field.data_type()) {
            // Struct-to-struct casts get a field-by-field compatibility check
            // rather than the generic Arrow cast check.
            (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
                validate_struct_compatibility(physical_fields, logical_fields)?;
            }
            _ => {
                let is_compatible =
                    can_cast_types(physical_field.data_type(), logical_field.data_type());
                if !is_compatible {
                    return exec_err!(
                        "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                        column.name(),
                        physical_field.data_type(),
                        logical_field.data_type()
                    );
                }
            }
        }

        // Wrap the (re-indexed) column in a cast from the physical to the
        // logical field type.
        let cast_expr = Arc::new(CastColumnExpr::new(
            Arc::new(column),
            Arc::new(physical_field.clone()),
            Arc::new(logical_field.clone()),
            None,
        ));

        Ok(Transformed::yes(cast_expr))
    }

    /// Look up the partition value recorded for `column_name`, if any,
    /// returning a clone of the stored `ScalarValue`.
    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_fields
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}