Skip to main content

vortex_array/arrays/scalar_fn/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10
11use crate::ArrayRef;
12use crate::array::Array;
13use crate::array::ArrayParts;
14use crate::array::TypedArrayRef;
15use crate::arrays::ScalarFnVTable;
16use crate::scalar_fn::ScalarFnRef;
17
// ScalarFnArray has a variable number of slots: one per child array.
19
20#[derive(Clone, Debug)]
21pub struct ScalarFnData {
22    pub(super) scalar_fn: ScalarFnRef,
23}
24
25impl Display for ScalarFnData {
26    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27        write!(f, "scalar_fn: {}", self.scalar_fn)
28    }
29}
30
31impl ScalarFnData {
32    /// Create a new ScalarFnArray from a scalar function and its children.
33    pub fn build(
34        scalar_fn: ScalarFnRef,
35        children: Vec<ArrayRef>,
36        len: usize,
37    ) -> VortexResult<Self> {
38        vortex_ensure!(
39            children.iter().all(|c| c.len() == len),
40            "ScalarFnArray must have children equal to the array length"
41        );
42        Ok(Self { scalar_fn })
43    }
44
45    /// Get the scalar function bound to this array.
46    #[allow(clippy::same_name_method)]
47    #[inline(always)]
48    pub fn scalar_fn(&self) -> &ScalarFnRef {
49        &self.scalar_fn
50    }
51}
52
53pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFnVTable> {
54    fn scalar_fn(&self) -> &ScalarFnRef {
55        &self.scalar_fn
56    }
57
58    fn child_at(&self, idx: usize) -> &ArrayRef {
59        self.as_ref().slots()[idx]
60            .as_ref()
61            .vortex_expect("ScalarFnArray child slot")
62    }
63
64    fn child_count(&self) -> usize {
65        self.as_ref().slots().len()
66    }
67
68    #[allow(clippy::same_name_method)]
69    fn nchildren(&self) -> usize {
70        self.child_count()
71    }
72
73    #[allow(clippy::same_name_method)]
74    fn get_child(&self, idx: usize) -> &ArrayRef {
75        self.child_at(idx)
76    }
77
78    fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
79        (0..self.child_count()).map(|idx| self.child_at(idx))
80    }
81
82    fn children(&self) -> Vec<ArrayRef> {
83        self.iter_children().cloned().collect()
84    }
85}
86impl<T: TypedArrayRef<ScalarFnVTable>> ScalarFnArrayExt for T {}
87
88impl Array<ScalarFnVTable> {
89    /// Create a new ScalarFnArray from a scalar function and its children.
90    pub fn try_new(
91        scalar_fn: ScalarFnRef,
92        children: Vec<ArrayRef>,
93        len: usize,
94    ) -> VortexResult<Self> {
95        let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
96        let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
97        let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
98        let vtable = ScalarFnVTable { scalar_fn };
99        Ok(unsafe {
100            Array::from_parts_unchecked(
101                ArrayParts::new(vtable, dtype, len, data)
102                    .with_slots(children.into_iter().map(Some).collect()),
103            )
104        })
105    }
106}