vortex_array/scalar_fn/internal/
row_count.rs1use std::fmt::Formatter;
5
6use vortex_array::ArrayRef;
7use vortex_array::ExecutionCtx;
8use vortex_array::arrays::ScalarFn;
9use vortex_array::arrays::scalar_fn::ExactScalarFn;
10use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
11use vortex_array::dtype::DType;
12use vortex_array::dtype::Nullability;
13use vortex_array::dtype::PType;
14use vortex_array::expr::Expression;
15use vortex_array::scalar_fn::Arity;
16use vortex_array::scalar_fn::ChildName;
17use vortex_array::scalar_fn::EmptyOptions;
18use vortex_array::scalar_fn::ExecutionArgs;
19use vortex_array::scalar_fn::ScalarFnId;
20use vortex_array::scalar_fn::ScalarFnVTable;
21use vortex_error::VortexResult;
22use vortex_error::vortex_bail;
23use vortex_error::vortex_ensure;
24
25#[derive(Clone)]
45pub struct RowCount;
46
47impl ScalarFnVTable for RowCount {
48 type Options = EmptyOptions;
49
50 fn id(&self) -> ScalarFnId {
51 ScalarFnId::from("vortex.row_count")
52 }
53
54 fn arity(&self, _options: &Self::Options) -> Arity {
55 Arity::Exact(0)
56 }
57
58 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
59 unreachable!("RowCount has arity 0")
60 }
61
62 fn fmt_sql(
63 &self,
64 _options: &Self::Options,
65 _expr: &Expression,
66 f: &mut Formatter<'_>,
67 ) -> std::fmt::Result {
68 write!(f, "row_count()")
69 }
70
71 fn return_dtype(&self, _options: &Self::Options, _args: &[DType]) -> VortexResult<DType> {
72 Ok(DType::Primitive(PType::U64, Nullability::NonNullable))
73 }
74
75 fn execute(
76 &self,
77 _options: &Self::Options,
78 _args: &dyn ExecutionArgs,
79 _ctx: &mut ExecutionCtx,
80 ) -> VortexResult<ArrayRef> {
81 vortex_bail!("RowCount must be substituted before evaluation")
82 }
83
84 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
85 false
86 }
87
88 fn is_fallible(&self, _options: &Self::Options) -> bool {
89 false
90 }
91}
92
93pub fn contains_row_count(array: &ArrayRef) -> bool {
101 if array.is::<ExactScalarFn<RowCount>>() {
102 return true;
103 }
104 match array.as_opt::<ScalarFn>() {
105 Some(view) => view.iter_children().any(contains_row_count),
106 None => false,
107 }
108}
109
110pub fn substitute_row_count(array: ArrayRef, replacement: &ArrayRef) -> VortexResult<ArrayRef> {
119 if array.is::<ExactScalarFn<RowCount>>() {
120 vortex_ensure!(
121 replacement.len() == array.len(),
122 "RowCount replacement length {} does not match scope length {}",
123 replacement.len(),
124 array.len(),
125 );
126 vortex_ensure!(
127 replacement.dtype() == array.dtype(),
128 "RowCount replacement dtype {} does not match scope dtype {}",
129 replacement.dtype(),
130 array.dtype(),
131 );
132 return Ok(replacement.clone());
133 }
134
135 if !array.is::<ScalarFn>() {
136 return Ok(array);
137 }
138
139 let nchildren = array.nchildren();
140 let mut array = array;
141 for slot_idx in 0..nchildren {
142 let (taken, child) = unsafe { array.take_slot_unchecked(slot_idx)? };
147 let new_child = substitute_row_count(child, replacement)?;
148 array = unsafe { taken.put_slot_unchecked(slot_idx, new_child)? };
149 }
150 Ok(array)
151}
152
#[cfg(test)]
mod tests {
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::scalar_fn::EmptyOptions;
    use crate::scalar_fn::internal::row_count::RowCount;
    use crate::scalar_fn::vtable::ScalarFnVTableExt;

    /// `row_count()` must report a non-nullable `u64` even when typed against an unrelated
    /// (here: nullable `i32`) scope.
    #[test]
    fn row_count_helper_dtype() {
        let scope = DType::Primitive(PType::I32, Nullability::Nullable);
        let expected = DType::Primitive(PType::U64, Nullability::NonNullable);

        let actual = RowCount
            .new_expr(EmptyOptions, [])
            .return_dtype(&scope)
            .unwrap();

        assert_eq!(actual, expected);
    }
}