Skip to main content

vortex_array/
normalize.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use smallvec::SmallVec;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7use vortex_session::registry::Id;
8use vortex_utils::aliases::hash_set::HashSet;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12
13/// Options for normalizing an array.
14pub struct NormalizeOptions<'a> {
15    /// The set of allowed array encodings (in addition to the canonical ones) that are permitted
16    /// in the normalized array.
17    pub allowed: &'a HashSet<Id>,
18    /// The operation to perform when a non-allowed encoding is encountered.
19    pub operation: Operation<'a>,
20}
21
22/// The operation to perform when a non-allowed encoding is encountered.
23pub enum Operation<'a> {
24    Error,
25    Execute(&'a mut ExecutionCtx),
26}
27
28impl ArrayRef {
29    /// Normalize the array according to given options.
30    ///
31    /// This operation performs a recursive traversal of the array. Any non-allowed encoding is
32    /// normalized per the configured operation.
33    pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult<ArrayRef> {
34        match &mut options.operation {
35            Operation::Error => {
36                self.normalize_with_error(options.allowed)?;
37                // Note this takes ownership so we can at a later date remove non-allowed encodings.
38                Ok(self)
39            }
40            Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx),
41        }
42    }
43
44    fn normalize_with_error(&self, allowed: &HashSet<Id>) -> VortexResult<()> {
45        if !self.is_allowed_encoding(allowed) {
46            vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id())
47        }
48
49        for child in self.children() {
50            child.normalize_with_error(allowed)?
51        }
52        Ok(())
53    }
54
55    fn normalize_with_execution(
56        self,
57        allowed: &HashSet<Id>,
58        ctx: &mut ExecutionCtx,
59    ) -> VortexResult<ArrayRef> {
60        let mut normalized = self;
61
62        // Top-first execute the array tree while we hit non-allowed encodings.
63        while !normalized.is_allowed_encoding(allowed) {
64            normalized = normalized.execute(ctx)?;
65        }
66
67        // Now we've normalized the root, we need to ensure the children are normalized also.
68        let slots = normalized.slots();
69        let mut normalized_slots = SmallVec::with_capacity(slots.len());
70        let mut any_slot_changed = false;
71
72        for slot in slots {
73            match slot {
74                Some(child) => {
75                    let normalized_child = child.clone().normalize(&mut NormalizeOptions {
76                        allowed,
77                        operation: Operation::Execute(ctx),
78                    })?;
79                    any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child);
80                    normalized_slots.push(Some(normalized_child));
81                }
82                None => normalized_slots.push(None),
83            }
84        }
85
86        if any_slot_changed {
87            // SAFETY: normalization only rewrites child slots to logically equivalent allowed
88            // encodings, preserving parent logical values and statistics.
89            normalized = unsafe { normalized.with_slots(normalized_slots) }?;
90        }
91
92        Ok(normalized)
93    }
94
95    fn is_allowed_encoding(&self, allowed: &HashSet<Id>) -> bool {
96        allowed.contains(&self.encoding_id()) || self.is_canonical()
97    }
98}
99
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_session::registry::Id;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::NormalizeOptions;
    use super::Operation;
    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::Slice;
    use crate::arrays::SliceArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// Build a primitive `i32` array from the given values.
    fn i32_array(values: impl IntoIterator<Item = i32>) -> ArrayRef {
        PrimitiveArray::from_iter(values).into_array()
    }

    /// Normalize `array` in execute mode against the `allowed` encoding set.
    fn normalize_executing(
        array: ArrayRef,
        allowed: &HashSet<Id>,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        array.normalize(&mut NormalizeOptions {
            allowed,
            operation: Operation::Execute(ctx),
        })
    }

    #[test]
    fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
        let column = i32_array(0..4);
        let parent = StructArray::try_new(
            ["field"].into(),
            vec![column.clone()],
            column.len(),
            Validity::NonNullable,
        )?
        .into_array();

        // Both the struct and its field are allowed, so nothing needs rewriting.
        let allowed = HashSet::from_iter([parent.encoding_id(), column.encoding_id()]);
        let mut ctx = crate::array_session().create_execution_ctx();

        let normalized = normalize_executing(parent.clone(), &allowed, &mut ctx)?;

        assert!(ArrayRef::ptr_eq(&parent, &normalized));
        Ok(())
    }

    #[test]
    fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> {
        let column = i32_array(0..4);
        let parent = StructArray::try_new(
            ["field"].into(),
            vec![column.clone()],
            column.len(),
            Validity::NonNullable,
        )?
        .into_array();

        // An empty allow-list: the arrays must pass purely by being canonical.
        let allowed = HashSet::default();

        let normalized = parent.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Error,
        })?;

        assert!(ArrayRef::ptr_eq(&parent, &normalized));
        Ok(())
    }

    #[test]
    fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
        let untouched = i32_array(0..4);
        let sliced = SliceArray::new(i32_array(10..20), 2..6).into_array();
        let parent = StructArray::try_new(
            ["lhs", "rhs"].into(),
            vec![untouched.clone(), sliced],
            untouched.len(),
            Validity::NonNullable,
        )?
        .into_array();

        // The slice encoding is deliberately left out of the allow-list.
        let allowed = HashSet::from_iter([parent.encoding_id(), untouched.encoding_id()]);
        let mut ctx = crate::array_session().create_execution_ctx();

        let normalized = normalize_executing(parent.clone(), &allowed, &mut ctx)?;

        assert!(!ArrayRef::ptr_eq(&parent, &normalized));

        let before = parent.children();
        let after = normalized.children();

        // The already-allowed child is reused as-is; only the sliced child is replaced.
        assert!(ArrayRef::ptr_eq(&before[0], &after[0]));
        assert!(!ArrayRef::ptr_eq(&before[1], &after[1]));
        assert_arrays_eq!(after[1], PrimitiveArray::from_iter(12i32..16), &mut ctx);

        Ok(())
    }

    #[test]
    fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> {
        let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array();
        let values = i32_array(vec![10, 20, 30]);
        let dict = DictArray::try_new(codes, values)?.into_array();

        // Wrap the dictionary in a slice so the root encoding is the non-allowed one.
        let sliced = SliceArray::new(dict, 1..4).into_array();
        assert_eq!(sliced.encoding_id(), Slice.id());

        let allowed = HashSet::from_iter([Dict.id(), Primitive.id()]);
        let mut ctx = crate::array_session().create_execution_ctx();

        let normalized = normalize_executing(sliced, &allowed, &mut ctx)?;

        // Normalization must leave a dictionary array at the root, not the slice wrapper.
        assert_eq!(normalized.encoding_id(), Dict.id());
        assert_eq!(normalized.len(), 3);

        // Codes [1, 0, 1] decode to values [20, 10, 20].
        #[expect(deprecated)]
        let normalized_canonical = normalized.to_canonical()?;
        assert_arrays_eq!(
            normalized_canonical,
            PrimitiveArray::from_iter(vec![20i32, 10, 20]),
            &mut ctx
        );

        Ok(())
    }
}