vortex_array/
normalize.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_bail;
6use vortex_session::registry::Id;
7use vortex_utils::aliases::hash_set::HashSet;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11
12pub struct NormalizeOptions<'a> {
14 pub allowed: &'a HashSet<Id>,
17 pub operation: Operation<'a>,
19}
20
21pub enum Operation<'a> {
23 Error,
24 Execute(&'a mut ExecutionCtx),
25}
26
27impl ArrayRef {
28 pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult<ArrayRef> {
33 match &mut options.operation {
34 Operation::Error => {
35 self.normalize_with_error(options.allowed)?;
36 Ok(self)
38 }
39 Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx),
40 }
41 }
42
43 fn normalize_with_error(&self, allowed: &HashSet<Id>) -> VortexResult<()> {
44 if !self.is_allowed_encoding(allowed) {
45 vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id())
46 }
47
48 for child in self.children() {
49 child.normalize_with_error(allowed)?
50 }
51 Ok(())
52 }
53
54 fn normalize_with_execution(
55 self,
56 allowed: &HashSet<Id>,
57 ctx: &mut ExecutionCtx,
58 ) -> VortexResult<ArrayRef> {
59 let mut normalized = self;
60
61 while !normalized.is_allowed_encoding(allowed) {
63 normalized = normalized.execute(ctx)?;
64 }
65
66 let slots = normalized.slots();
68 let mut normalized_slots = Vec::with_capacity(slots.len());
69 let mut any_slot_changed = false;
70
71 for slot in slots {
72 match slot {
73 Some(child) => {
74 let normalized_child = child.clone().normalize(&mut NormalizeOptions {
75 allowed,
76 operation: Operation::Execute(ctx),
77 })?;
78 any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child);
79 normalized_slots.push(Some(normalized_child));
80 }
81 None => normalized_slots.push(None),
82 }
83 }
84
85 if any_slot_changed {
86 normalized = normalized.with_slots(normalized_slots)?;
87 }
88
89 Ok(normalized)
90 }
91
92 fn is_allowed_encoding(&self, allowed: &HashSet<Id>) -> bool {
93 allowed.contains(&self.encoding_id()) || self.is_canonical()
94 }
95}
96
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::NormalizeOptions;
    use super::Operation;
    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::Slice;
    use crate::arrays::SliceArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// When every encoding in the tree is already allowed, the very same array is returned.
    #[test]
    fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
        let field = PrimitiveArray::from_iter(0i32..4).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![field.clone()],
            field.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(ArrayRef::ptr_eq(&array, &normalized));
        Ok(())
    }

    /// Canonical arrays pass validation even when the allowed set is empty.
    #[test]
    fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> {
        let field = PrimitiveArray::from_iter(0i32..4).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![field.clone()],
            field.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::default();

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Error,
        })?;

        assert!(ArrayRef::ptr_eq(&array, &normalized));
        Ok(())
    }

    /// A disallowed child forces the parent to be rebuilt, while untouched children are reused.
    #[test]
    fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
        let unchanged = PrimitiveArray::from_iter(0i32..4).into_array();
        let sliced =
            SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array();
        let array = StructArray::try_new(
            ["lhs", "rhs"].into(),
            vec![unchanged.clone(), sliced],
            unchanged.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(!ArrayRef::ptr_eq(&array, &normalized));

        let original_children = array.children();
        let normalized_children = normalized.children();
        assert!(ArrayRef::ptr_eq(
            &original_children[0],
            &normalized_children[0]
        ));
        assert!(!ArrayRef::ptr_eq(
            &original_children[1],
            &normalized_children[1]
        ));
        assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16));

        Ok(())
    }

    /// Normalizing a slice over a dictionary array, with only dict and primitive allowed,
    /// yields a dictionary array holding the sliced values.
    #[test]
    fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> {
        let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array();
        let values = PrimitiveArray::from_iter(vec![10i32, 20, 30]).into_array();
        let dict = DictArray::try_new(codes, values)?.into_array();

        let sliced = SliceArray::new(dict, 1..4).into_array();
        assert_eq!(sliced.encoding_id(), Slice::ID);

        let allowed = HashSet::from_iter([Dict::ID, Primitive::ID]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        // Render the input tree up front (`sliced` is consumed below) so it can be reported
        // in failure messages instead of being printed unconditionally.
        let sliced_tree = sliced.display_tree().to_string();

        let normalized = sliced.normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert_eq!(
            normalized.encoding_id(),
            Dict::ID,
            "sliced {}\nafter {}",
            sliced_tree,
            normalized.display_tree()
        );
        assert_eq!(
            normalized.len(),
            3,
            "sliced {}\nafter {}",
            sliced_tree,
            normalized.display_tree()
        );

        assert_arrays_eq!(
            normalized.to_canonical()?,
            PrimitiveArray::from_iter(vec![20i32, 10, 20])
        );

        Ok(())
    }
}