Skip to main content

vortex_array/arrays/scalar_fn/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10
11use crate::ArrayRef;
12use crate::ArraySlots;
13use crate::array::Array;
14use crate::array::ArrayParts;
15use crate::array::TypedArrayRef;
16use crate::arrays::ScalarFn;
17use crate::scalar_fn::ScalarFnRef;
18
19// ScalarFnArray has a variable number of slots (one per child)
20
21#[derive(Clone, Debug)]
22pub struct ScalarFnData {
23    pub(super) scalar_fn: ScalarFnRef,
24}
25
26impl Display for ScalarFnData {
27    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28        write!(f, "scalar_fn: {}", self.scalar_fn)
29    }
30}
31
32impl ScalarFnData {
33    /// Create a new ScalarFnArray from a scalar function and its children.
34    pub fn build(
35        scalar_fn: ScalarFnRef,
36        children: Vec<ArrayRef>,
37        len: usize,
38    ) -> VortexResult<Self> {
39        vortex_ensure!(
40            children.iter().all(|c| c.len() == len),
41            "ScalarFnArray must have children equal to the array length"
42        );
43        Ok(Self { scalar_fn })
44    }
45
46    /// Get the scalar function bound to this array.
47    #[inline(always)]
48    pub fn scalar_fn(&self) -> &ScalarFnRef {
49        &self.scalar_fn
50    }
51}
52
53pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFn> {
54    fn scalar_fn(&self) -> &ScalarFnRef {
55        &self.scalar_fn
56    }
57
58    fn child_at(&self, idx: usize) -> &ArrayRef {
59        self.as_ref().slots()[idx]
60            .as_ref()
61            .vortex_expect("ScalarFnArray child slot")
62    }
63
64    fn child_count(&self) -> usize {
65        self.as_ref().slots().len()
66    }
67
68    fn nchildren(&self) -> usize {
69        self.child_count()
70    }
71
72    fn get_child(&self, idx: usize) -> &ArrayRef {
73        self.child_at(idx)
74    }
75
76    fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
77        (0..self.child_count()).map(|idx| self.child_at(idx))
78    }
79
80    fn children(&self) -> Vec<ArrayRef> {
81        self.iter_children().cloned().collect()
82    }
83}
84impl<T: TypedArrayRef<ScalarFn>> ScalarFnArrayExt for T {}
85
86impl Array<ScalarFn> {
87    /// Create a new ScalarFnArray from a scalar function and its children.
88    pub fn try_new(
89        scalar_fn: ScalarFnRef,
90        children: Vec<ArrayRef>,
91        len: usize,
92    ) -> VortexResult<Self> {
93        let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
94        let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
95        let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
96        let vtable = ScalarFn { id: scalar_fn.id() };
97        Ok(unsafe {
98            Array::from_parts_unchecked(
99                ArrayParts::new(vtable, dtype, len, data)
100                    .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
101            )
102        })
103    }
104}