// vortex_array/arrays/scalar_fn/array.rs
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10
11use crate::ArrayRef;
12use crate::ArraySlots;
13use crate::array::Array;
14use crate::array::ArrayParts;
15use crate::array::TypedArrayRef;
16use crate::arrays::ScalarFn;
17use crate::scalar_fn::ScalarFnRef;
18
19#[derive(Clone, Debug)]
22pub struct ScalarFnData {
23 pub(super) scalar_fn: ScalarFnRef,
24}
25
26impl Display for ScalarFnData {
27 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28 write!(f, "scalar_fn: {}", self.scalar_fn)
29 }
30}
31
32impl ScalarFnData {
33 pub fn build(
35 scalar_fn: ScalarFnRef,
36 children: Vec<ArrayRef>,
37 len: usize,
38 ) -> VortexResult<Self> {
39 vortex_ensure!(
40 children.iter().all(|c| c.len() == len),
41 "ScalarFnArray must have children equal to the array length"
42 );
43 Ok(Self { scalar_fn })
44 }
45
46 #[inline(always)]
48 pub fn scalar_fn(&self) -> &ScalarFnRef {
49 &self.scalar_fn
50 }
51}
52
53pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFn> {
54 fn scalar_fn(&self) -> &ScalarFnRef {
55 &self.scalar_fn
56 }
57
58 fn child_at(&self, idx: usize) -> &ArrayRef {
59 self.as_ref().slots()[idx]
60 .as_ref()
61 .vortex_expect("ScalarFnArray child slot")
62 }
63
64 fn child_count(&self) -> usize {
65 self.as_ref().slots().len()
66 }
67
68 fn nchildren(&self) -> usize {
69 self.child_count()
70 }
71
72 fn get_child(&self, idx: usize) -> &ArrayRef {
73 self.child_at(idx)
74 }
75
76 fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
77 (0..self.child_count()).map(|idx| self.child_at(idx))
78 }
79
80 fn children(&self) -> Vec<ArrayRef> {
81 self.iter_children().cloned().collect()
82 }
83}
84impl<T: TypedArrayRef<ScalarFn>> ScalarFnArrayExt for T {}
85
86impl Array<ScalarFn> {
87 pub fn try_new(
89 scalar_fn: ScalarFnRef,
90 children: Vec<ArrayRef>,
91 len: usize,
92 ) -> VortexResult<Self> {
93 let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
94 let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
95 let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
96 let vtable = ScalarFn { id: scalar_fn.id() };
97 Ok(unsafe {
98 Array::from_parts_unchecked(
99 ArrayParts::new(vtable, dtype, len, data)
100 .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
101 )
102 })
103 }
104}