vortex_array/arrays/scalar_fn/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10
11use crate::ArrayRef;
12use crate::array::Array;
13use crate::array::ArrayParts;
14use crate::array::TypedArrayRef;
15use crate::arrays::ScalarFnVTable;
16use crate::scalar_fn::ScalarFnRef;
17
18#[derive(Clone, Debug)]
21pub struct ScalarFnData {
22 pub(super) scalar_fn: ScalarFnRef,
23}
24
25impl Display for ScalarFnData {
26 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27 write!(f, "scalar_fn: {}", self.scalar_fn)
28 }
29}
30
31impl ScalarFnData {
32 pub fn build(
34 scalar_fn: ScalarFnRef,
35 children: Vec<ArrayRef>,
36 len: usize,
37 ) -> VortexResult<Self> {
38 vortex_ensure!(
39 children.iter().all(|c| c.len() == len),
40 "ScalarFnArray must have children equal to the array length"
41 );
42 Ok(Self { scalar_fn })
43 }
44
45 #[inline(always)]
47 pub fn scalar_fn(&self) -> &ScalarFnRef {
48 &self.scalar_fn
49 }
50}
51
52pub trait ScalarFnArrayExt: TypedArrayRef<ScalarFnVTable> {
53 fn scalar_fn(&self) -> &ScalarFnRef {
54 &self.scalar_fn
55 }
56
57 fn child_at(&self, idx: usize) -> &ArrayRef {
58 self.as_ref().slots()[idx]
59 .as_ref()
60 .vortex_expect("ScalarFnArray child slot")
61 }
62
63 fn child_count(&self) -> usize {
64 self.as_ref().slots().len()
65 }
66
67 fn nchildren(&self) -> usize {
68 self.child_count()
69 }
70
71 fn get_child(&self, idx: usize) -> &ArrayRef {
72 self.child_at(idx)
73 }
74
75 fn iter_children(&self) -> impl Iterator<Item = &ArrayRef> + '_ {
76 (0..self.child_count()).map(|idx| self.child_at(idx))
77 }
78
79 fn children(&self) -> Vec<ArrayRef> {
80 self.iter_children().cloned().collect()
81 }
82}
83impl<T: TypedArrayRef<ScalarFnVTable>> ScalarFnArrayExt for T {}
84
85impl Array<ScalarFnVTable> {
86 pub fn try_new(
88 scalar_fn: ScalarFnRef,
89 children: Vec<ArrayRef>,
90 len: usize,
91 ) -> VortexResult<Self> {
92 let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
93 let dtype = scalar_fn.return_dtype(&arg_dtypes)?;
94 let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?;
95 let vtable = ScalarFnVTable { id: scalar_fn.id() };
96 Ok(unsafe {
97 Array::from_parts_unchecked(
98 ArrayParts::new(vtable, dtype, len, data)
99 .with_slots(children.into_iter().map(Some).collect()),
100 )
101 })
102 }
103}