Skip to main content

vortex_array/
normalize.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_session::registry::Id;
7use vortex_utils::aliases::hash_set::HashSet;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11
12/// Options for normalizing an array.
13pub struct NormalizeOptions<'a> {
14    /// The set of allowed array encodings (in addition to the canonical ones) that are permitted
15    /// in the normalized array.
16    pub allowed: &'a HashSet<Id>,
17    /// The operation to perform when a non-allowed encoding is encountered.
18    pub operation: Operation<'a>,
19}
20
21/// The operation to perform when a non-allowed encoding is encountered.
22pub enum Operation<'a> {
23    Error,
24    Execute(&'a mut ExecutionCtx),
25}
26
27impl ArrayRef {
28    /// Normalize the array according to given options.
29    ///
30    /// This operation performs a recursive traversal of the array. Any non-allowed encoding is
31    /// normalized per the configured operation.
32    pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult<ArrayRef> {
33        match &mut options.operation {
34            Operation::Error => {
35                self.normalize_with_error(options.allowed)?;
36                // Note this takes ownership so we can at a later date remove non-allowed encodings.
37                Ok(self)
38            }
39            Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx),
40        }
41    }
42
43    fn normalize_with_error(&self, allowed: &HashSet<Id>) -> VortexResult<()> {
44        if !self.is_allowed_encoding(allowed) {
45            vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id())
46        }
47
48        for child in self.children() {
49            child.normalize_with_error(allowed)?
50        }
51        Ok(())
52    }
53
54    fn normalize_with_execution(
55        self,
56        allowed: &HashSet<Id>,
57        ctx: &mut ExecutionCtx,
58    ) -> VortexResult<ArrayRef> {
59        let mut normalized = self;
60
61        // Top-first execute the array tree while we hit non-allowed encodings.
62        while !normalized.is_allowed_encoding(allowed) {
63            normalized = normalized.execute(ctx)?;
64        }
65
66        // Now we've normalized the root, we need to ensure the children are normalized also.
67        let slots = normalized.slots();
68        let mut normalized_slots = Vec::with_capacity(slots.len());
69        let mut any_slot_changed = false;
70
71        for slot in slots {
72            match slot {
73                Some(child) => {
74                    let normalized_child = child.clone().normalize(&mut NormalizeOptions {
75                        allowed,
76                        operation: Operation::Execute(ctx),
77                    })?;
78                    any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child);
79                    normalized_slots.push(Some(normalized_child));
80                }
81                None => normalized_slots.push(None),
82            }
83        }
84
85        if any_slot_changed {
86            normalized = normalized.with_slots(normalized_slots)?;
87        }
88
89        Ok(normalized)
90    }
91
92    fn is_allowed_encoding(&self, allowed: &HashSet<Id>) -> bool {
93        allowed.contains(&self.encoding_id()) || self.is_canonical()
94    }
95}
96
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::NormalizeOptions;
    use super::Operation;
    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::Slice;
    use crate::arrays::SliceArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    #[test]
    fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
        let field = PrimitiveArray::from_iter(0i32..4).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![field.clone()],
            field.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(ArrayRef::ptr_eq(&array, &normalized));
        Ok(())
    }

    #[test]
    fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> {
        let field = PrimitiveArray::from_iter(0i32..4).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![field.clone()],
            field.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::default();

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Error,
        })?;

        assert!(ArrayRef::ptr_eq(&array, &normalized));
        Ok(())
    }

    #[test]
    fn normalize_with_error_rejects_non_allowed_child() -> VortexResult<()> {
        // The canonical struct parent is allowed, but its Slice child is not, so the recursive
        // check must surface an error from below the root.
        let sliced =
            SliceArray::new(PrimitiveArray::from_iter(0i32..8).into_array(), 2..6).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![sliced.clone()],
            sliced.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::default();

        let result = array.normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Error,
        });

        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
        let unchanged = PrimitiveArray::from_iter(0i32..4).into_array();
        let sliced =
            SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array();
        let array = StructArray::try_new(
            ["lhs", "rhs"].into(),
            vec![unchanged.clone(), sliced],
            unchanged.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(!ArrayRef::ptr_eq(&array, &normalized));

        let original_children = array.children();
        let normalized_children = normalized.children();
        assert!(ArrayRef::ptr_eq(
            &original_children[0],
            &normalized_children[0]
        ));
        assert!(!ArrayRef::ptr_eq(
            &original_children[1],
            &normalized_children[1]
        ));
        assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16));

        Ok(())
    }

    #[test]
    fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> {
        let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array();
        let values = PrimitiveArray::from_iter(vec![10i32, 20, 30]).into_array();
        let dict = DictArray::try_new(codes, values)?.into_array();

        // Slice the dict array to get a SliceArray wrapping a DictArray.
        let sliced = SliceArray::new(dict, 1..4).into_array();
        assert_eq!(sliced.encoding_id(), Slice::ID);

        let allowed = HashSet::from_iter([Dict::ID, Primitive::ID]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = sliced.normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        // The normalized result should be a DictArray, not a SliceArray.
        assert_eq!(normalized.encoding_id(), Dict::ID);
        assert_eq!(normalized.len(), 3);

        // Verify the data: codes [1,0,1] -> values [20, 10, 20]
        assert_arrays_eq!(
            normalized.to_canonical()?,
            PrimitiveArray::from_iter(vec![20i32, 10, 20])
        );

        Ok(())
    }
}