Skip to main content

vortex_array/arrays/scalar_fn/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10
11use crate::ArrayRef;
12use crate::array::Array;
13use crate::array::ArrayParts;
14use crate::array::TypedArrayRef;
15use crate::arrays::ScalarFnVTable;
16use crate::scalar_fn::ScalarFnRef;
17
// A ScalarFnArray has a variable number of slots: one per child array.
19
20#[derive(Clone, Debug)]
21pub struct ScalarFnData {
22    pub(super) scalar_fn: ScalarFnRef,
23}
24
25impl Display for ScalarFnData {
26    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27        write!(f, "scalar_fn: {}", self.scalar_fn)
28    }
29}
30
31impl ScalarFnData {
32    /// Create a new ScalarFnArray from a scalar function and its children.
33    pub fn build(
34        scalar_fn: ScalarFnRef,
35        children: Vec<ArrayRef>,
36        len: usize,
37    ) -> VortexResult<Self> {
38        vortex_ensure!(
39            children.iter().all(|c| c.len() == len),
40            "ScalarFnArray must have children equal to the array length"
41        );
42        Ok(Self { scalar_fn })
43    }
44
45    /// Get the scalar function bound to this array.
46    #[inline(always)]
47    pub fn scalar_fn(&self) -> &ScalarFnRef {
48        &self.scalar_fn
49    }
50}
51
52pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFnVTable> {
53    fn scalar_fn(&self) -> &ScalarFnRef {
54        &self.scalar_fn
55    }
56
57    fn child_at(&self, idx: usize) -> &ArrayRef {
58        self.as_ref().slots()[idx]
59            .as_ref()
60            .vortex_expect("ScalarFnArray child slot")
61    }
62
63    fn child_count(&self) -> usize {
64        self.as_ref().slots().len()
65    }
66
67    fn nchildren(&self) -> usize {
68        self.child_count()
69    }
70
71    fn get_child(&self, idx: usize) -> &ArrayRef {
72        self.child_at(idx)
73    }
74
75    fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
76        (0..self.child_count()).map(|idx| self.child_at(idx))
77    }
78
79    fn children(&self) -> Vec<ArrayRef> {
80        self.iter_children().cloned().collect()
81    }
82}
83impl<T: TypedArrayRef<ScalarFnVTable>> ScalarFnArrayExt for T {}
84
85impl Array<ScalarFnVTable> {
86    /// Create a new ScalarFnArray from a scalar function and its children.
87    pub fn try_new(
88        scalar_fn: ScalarFnRef,
89        children: Vec<ArrayRef>,
90        len: usize,
91    ) -> VortexResult<Self> {
92        let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
93        let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
94        let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
95        let vtable = ScalarFnVTable { id: scalar_fn.id() };
96        Ok(unsafe {
97            Array::from_parts_unchecked(
98                ArrayParts::new(vtable, dtype, len, data)
99                    .with_slots(children.into_iter().map(Some).collect()),
100            )
101        })
102    }
103}