Skip to main content

vortex_array/arrays/scalar_fn/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11
12use crate::ArrayRef;
13use crate::ArraySlots;
14use crate::array::Array;
15use crate::array::ArrayParts;
16use crate::array::TypedArrayRef;
17use crate::arrays::ScalarFn;
18use crate::scalar_fn::ScalarFnRef;
19
20// ScalarFnArray has a variable number of slots (one per child)
21
22#[derive(Clone, Debug)]
23pub struct ScalarFnData {
24    pub(super) scalar_fn: ScalarFnRef,
25}
26
27impl Display for ScalarFnData {
28    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29        write!(f, "scalar_fn: {}", self.scalar_fn)
30    }
31}
32
33impl ScalarFnData {
34    /// Get the scalar function bound to this array.
35    #[inline(always)]
36    pub fn scalar_fn(&self) -> &ScalarFnRef {
37        &self.scalar_fn
38    }
39}
40
41pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFn> {
42    fn scalar_fn(&self) -> &ScalarFnRef {
43        &self.scalar_fn
44    }
45
46    fn child_at(&self, idx: usize) -> &ArrayRef {
47        self.as_ref().slots()[idx]
48            .as_ref()
49            .vortex_expect("ScalarFnArray child slot")
50    }
51
52    fn child_count(&self) -> usize {
53        self.as_ref().slots().len()
54    }
55
56    fn nchildren(&self) -> usize {
57        self.child_count()
58    }
59
60    fn get_child(&self, idx: usize) -> &ArrayRef {
61        self.child_at(idx)
62    }
63
64    fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
65        (0..self.child_count()).map(|idx| self.child_at(idx))
66    }
67
68    fn children(&self) -> Vec<ArrayRef> {
69        self.iter_children().cloned().collect()
70    }
71}
72impl<T: TypedArrayRef<ScalarFn>> ScalarFnArrayExt for T {}
73
74impl Array<ScalarFn> {
75    /// Create a new ScalarFnArray from a scalar function and its children.
76    pub fn try_new(scalar_fn: ScalarFnRef, children: Vec<ArrayRef>) -> VortexResult<Self> {
77        let len = Self::infer_len(&children)?;
78        Self::try_new_with_len(scalar_fn, children, len)
79    }
80
81    /// Create a new ScalarFnArray from a scalar function, children, and an explicit length.
82    ///
83    /// This is needed for zero-child scalar functions and deserialization paths where there is no
84    /// child array to infer the length from.
85    pub fn try_new_with_len(
86        scalar_fn: ScalarFnRef,
87        children: Vec<ArrayRef>,
88        len: usize,
89    ) -> VortexResult<Self> {
90        Self::validate_children_len(&children, len)?;
91        let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
92        let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
93        let data = ScalarFnData {
94            scalar_fn: scalar_fn.clone(),
95        };
96        let vtable = ScalarFn { id: scalar_fn.id() };
97        Ok(unsafe {
98            Array::from_parts_unchecked(
99                ArrayParts::new(vtable, dtype, len, data)
100                    .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
101            )
102        })
103    }
104
105    fn infer_len(children: &[ArrayRef]) -> VortexResult<usize> {
106        let Some(child) = children.first() else {
107            vortex_bail!("ScalarFnArray length cannot be inferred without children");
108        };
109        Ok(child.len())
110    }
111
112    fn validate_children_len(children: &[ArrayRef], len: usize) -> VortexResult<()> {
113        vortex_ensure!(
114            children.iter().all(|c| c.len() == len),
115            "ScalarFnArray must have children equal to the array length"
116        );
117        Ok(())
118    }
119}