vortex_array/arrays/scalar_fn/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_ensure;
11
12use crate::ArrayRef;
13use crate::ArraySlots;
14use crate::array::Array;
15use crate::array::ArrayParts;
16use crate::array::TypedArrayRef;
17use crate::arrays::ScalarFn;
18use crate::scalar_fn::ScalarFnRef;
19
20#[derive(Clone, Debug)]
23pub struct ScalarFnData {
24 pub(super) scalar_fn: ScalarFnRef,
25}
26
27impl Display for ScalarFnData {
28 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29 write!(f, "scalar_fn: {}", self.scalar_fn)
30 }
31}
32
33impl ScalarFnData {
34 #[inline(always)]
36 pub fn scalar_fn(&self) -> &ScalarFnRef {
37 &self.scalar_fn
38 }
39}
40
41pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFn> {
42 fn scalar_fn(&self) -> &ScalarFnRef {
43 &self.scalar_fn
44 }
45
46 fn child_at(&self, idx: usize) -> &ArrayRef {
47 self.as_ref().slots()[idx]
48 .as_ref()
49 .vortex_expect("ScalarFnArray child slot")
50 }
51
52 fn child_count(&self) -> usize {
53 self.as_ref().slots().len()
54 }
55
56 fn nchildren(&self) -> usize {
57 self.child_count()
58 }
59
60 fn get_child(&self, idx: usize) -> &ArrayRef {
61 self.child_at(idx)
62 }
63
64 fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
65 (0..self.child_count()).map(|idx| self.child_at(idx))
66 }
67
68 fn children(&self) -> Vec<ArrayRef> {
69 self.iter_children().cloned().collect()
70 }
71}
72impl<T: TypedArrayRef<ScalarFn>> ScalarFnArrayExt for T {}
73
74impl Array<ScalarFn> {
75 pub fn try_new(scalar_fn: ScalarFnRef, children: Vec<ArrayRef>) -> VortexResult<Self> {
77 let len = Self::infer_len(&children)?;
78 Self::try_new_with_len(scalar_fn, children, len)
79 }
80
81 pub fn try_new_with_len(
86 scalar_fn: ScalarFnRef,
87 children: Vec<ArrayRef>,
88 len: usize,
89 ) -> VortexResult<Self> {
90 Self::validate_children_len(&children, len)?;
91 let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
92 let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
93 let data = ScalarFnData {
94 scalar_fn: scalar_fn.clone(),
95 };
96 let vtable = ScalarFn { id: scalar_fn.id() };
97 Ok(unsafe {
98 Array::from_parts_unchecked(
99 ArrayParts::new(vtable, dtype, len, data)
100 .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
101 )
102 })
103 }
104
105 fn infer_len(children: &[ArrayRef]) -> VortexResult<usize> {
106 let Some(child) = children.first() else {
107 vortex_bail!("ScalarFnArray length cannot be inferred without children");
108 };
109 Ok(child.len())
110 }
111
112 fn validate_children_len(children: &[ArrayRef], len: usize) -> VortexResult<()> {
113 vortex_ensure!(
114 children.iter().all(|c| c.len() == len),
115 "ScalarFnArray must have children equal to the array length"
116 );
117 Ok(())
118 }
119}