vortex_array/expr/analysis/
referenced_field_paths.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_err;
6
7use crate::dtype::DType;
8use crate::dtype::Field;
9use crate::dtype::FieldPath;
10use crate::dtype::FieldPathSet;
11use crate::expr::Expression;
12use crate::expr::traversal::FoldDownContext;
13use crate::expr::traversal::FoldUp;
14use crate::expr::traversal::NodeExt;
15use crate::expr::traversal::NodeFolderContext;
16use crate::scalar_fn::fns::get_item::GetItem;
17use crate::scalar_fn::fns::root::Root;
18use crate::scalar_fn::fns::select::Select;
19
20pub fn referenced_field_paths(expr: &Expression, scope: &DType) -> VortexResult<FieldPathSet> {
28 expr.return_dtype(scope)?;
30
31 let mut collector = ReferencedFieldPaths {
32 scope,
33 field_paths: FieldPathSet::default(),
34 };
35 expr.clone()
36 .fold_context(&vec![FieldPath::root()], &mut collector)?;
37 let field_paths = collector.field_paths;
38
39 #[cfg(debug_assertions)]
43 if let Some(scope_fields) = scope.as_struct_fields_opt() {
44 use vortex_utils::aliases::hash_set::HashSet;
45
46 use crate::dtype::FieldName;
47 use crate::expr::analysis::immediate_access::immediate_scope_access;
48
49 let referenced_heads: HashSet<FieldName> = if field_paths.iter().any(FieldPath::is_root) {
50 scope_fields.names().iter().cloned().collect()
51 } else {
52 field_paths
53 .iter()
54 .filter_map(|path| match path.parts().first() {
55 Some(Field::Name(name)) => Some(name.clone()),
56 _ => None,
57 })
58 .collect()
59 };
60 debug_assert_eq!(
61 referenced_heads,
62 immediate_scope_access(expr, scope_fields),
63 "referenced field path heads must match the immediately accessed scope fields"
64 );
65 }
66
67 Ok(field_paths)
68}
69
70struct ReferencedFieldPaths<'a> {
82 scope: &'a DType,
83 field_paths: FieldPathSet,
84}
85
86impl NodeFolderContext for ReferencedFieldPaths<'_> {
87 type NodeTy = Expression;
88 type Result = ();
89 type Context = Vec<FieldPath>;
90
91 fn visit_down(
92 &mut self,
93 requested: &Self::Context,
94 node: &Expression,
95 ) -> VortexResult<FoldDownContext<Self::Context, ()>> {
96 if node.is::<Root>() {
97 self.field_paths.extend(
98 requested
99 .iter()
100 .map(|path| FieldPath::from_iter(path.parts().iter().rev().cloned())),
101 );
102 return Ok(FoldDownContext::Skip(()));
103 }
104
105 if let Some(field_name) = node.as_opt::<GetItem>() {
106 let appended = requested
107 .iter()
108 .map(|path| path.clone().push(Field::Name(field_name.clone())))
109 .collect();
110 return Ok(FoldDownContext::Continue(appended));
111 }
112
113 if let Some(selection) = node.as_opt::<Select>() {
116 let child_dtype = node.child(0).return_dtype(self.scope)?;
117 let child_fields = child_dtype
118 .as_struct_fields_opt()
119 .ok_or_else(|| vortex_err!("Select child is not a struct"))?;
120 let included_fields = selection.normalize_to_included_fields(child_fields.names())?;
121
122 let mut narrowed = Vec::with_capacity(requested.len());
123 for path in requested {
124 if path.is_root() {
125 narrowed.extend(included_fields.iter().cloned().map(FieldPath::from_name));
126 } else if let Some(Field::Name(field_name)) = path.parts().last()
127 && included_fields
128 .iter()
129 .any(|included| included == field_name)
130 {
131 narrowed.push(path.clone());
132 }
133 }
134
135 if narrowed.is_empty() {
138 return Ok(FoldDownContext::Skip(()));
139 }
140 return Ok(FoldDownContext::Continue(narrowed));
141 }
142
143 Ok(FoldDownContext::Continue(vec![FieldPath::root()]))
145 }
146
147 fn visit_up(
148 &mut self,
149 _node: Expression,
150 _requested: &Self::Context,
151 _children: Vec<()>,
152 ) -> VortexResult<FoldUp<()>> {
153 Ok(FoldUp::Continue(()))
154 }
155}
156
#[cfg(test)]
mod tests {
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::get_item;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;

    /// Scope `{ a: { x: i32, y: i32 } }` shared by every test.
    fn scope() -> DType {
        DType::Struct(
            StructFields::from_iter([(
                "a",
                DType::Struct(
                    StructFields::from_iter([("x", I32), ("y", I32)]),
                    NonNullable,
                ),
            )]),
            NonNullable,
        )
    }

    /// Runs the analysis against [`scope`] and returns the paths as a plain set for
    /// order-independent comparison.
    fn referenced(expr: &Expression) -> VortexResult<HashSet<FieldPath>> {
        Ok(referenced_field_paths(expr, &scope())?
            .into_iter()
            .collect())
    }

    #[test]
    fn nested_select_preserves_field_path() -> VortexResult<()> {
        let expr = select(["x"], get_item("a", root()));

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn get_item_after_select_only_references_requested_field() -> VortexResult<()> {
        let expr = get_item("x", select(["x", "y"], get_item("a", root())));

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn select_exclude_references_included_fields() -> VortexResult<()> {
        let expr = select_exclude(["y"], get_item("a", root()));

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn ancestor_path_subsumes_descendant() -> VortexResult<()> {
        let expr = pack(
            [
                ("a", get_item("a", root())),
                ("x", get_item("x", get_item("a", root()))),
            ],
            NonNullable,
        );

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a")])
        );
        Ok(())
    }

    #[test]
    fn get_item_through_opaque_fn_references_all_fields() -> VortexResult<()> {
        // `pack` is opaque to the analysis, so its child `root()` is read in full.
        let expr = get_item("x", pack([("x", root())], NonNullable));

        assert_eq!(referenced(&expr)?, HashSet::from_iter([FieldPath::root()]));
        Ok(())
    }

    #[test]
    fn root_references_all_fields() -> VortexResult<()> {
        assert_eq!(
            referenced(&root())?,
            HashSet::from_iter([FieldPath::root()])
        );
        Ok(())
    }

    #[test]
    fn invalid_get_item_path_returns_error() {
        assert!(referenced_field_paths(&get_item("missing", root()), &scope()).is_err());
    }

    /// A plain `GetItem` chain (no `Select`) resolves to the full leaf path in scope order.
    #[test]
    fn nested_get_item_references_leaf_path() -> VortexResult<()> {
        let expr = get_item("x", get_item("a", root()));

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    /// `Select` applied directly to the scope references each selected top-level field.
    #[test]
    fn select_on_root_references_selected_top_level_field() -> VortexResult<()> {
        let expr = select(["a"], root());

        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a")])
        );
        Ok(())
    }
}