vortex_array/scalar_fn/internal/
row_count.rs1use std::fmt::Formatter;
5
6use vortex_array::ArrayRef;
7use vortex_array::ExecutionCtx;
8use vortex_array::arrays::ScalarFn;
9use vortex_array::arrays::scalar_fn::ExactScalarFn;
10use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
11use vortex_array::dtype::DType;
12use vortex_array::dtype::Nullability;
13use vortex_array::dtype::PType;
14use vortex_array::expr::Expression;
15use vortex_array::scalar_fn::Arity;
16use vortex_array::scalar_fn::ChildName;
17use vortex_array::scalar_fn::EmptyOptions;
18use vortex_array::scalar_fn::ExecutionArgs;
19use vortex_array::scalar_fn::ScalarFnId;
20use vortex_array::scalar_fn::ScalarFnVTable;
21use vortex_error::VortexResult;
22use vortex_error::vortex_bail;
23use vortex_error::vortex_ensure;
24use vortex_session::registry::CachedId;
25
/// Marker type for the `vortex.row_count` scalar function.
///
/// The function takes no children and reports a non-nullable `u64` dtype. Its `execute`
/// implementation always fails, so every occurrence has to be replaced with a concrete
/// array (see `substitute_row_count`) before the surrounding expression is evaluated.
#[derive(Clone)]
pub struct RowCount;
47
48impl ScalarFnVTable for RowCount {
49 type Options = EmptyOptions;
50
51 fn id(&self) -> ScalarFnId {
52 static ID: CachedId = CachedId::new("vortex.row_count");
53 *ID
54 }
55
56 fn arity(&self, _options: &Self::Options) -> Arity {
57 Arity::Exact(0)
58 }
59
60 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
61 unreachable!("RowCount has arity 0")
62 }
63
64 fn fmt_sql(
65 &self,
66 _options: &Self::Options,
67 _expr: &Expression,
68 f: &mut Formatter<'_>,
69 ) -> std::fmt::Result {
70 write!(f, "row_count()")
71 }
72
73 fn return_dtype(&self, _options: &Self::Options, _args: &[DType]) -> VortexResult<DType> {
74 Ok(DType::Primitive(PType::U64, Nullability::NonNullable))
75 }
76
77 fn execute(
78 &self,
79 _options: &Self::Options,
80 _args: &dyn ExecutionArgs,
81 _ctx: &mut ExecutionCtx,
82 ) -> VortexResult<ArrayRef> {
83 vortex_bail!("RowCount must be substituted before evaluation")
84 }
85
86 fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
87 false
88 }
89
90 fn is_fallible(&self, _options: &Self::Options) -> bool {
91 false
92 }
93}
94
95pub fn contains_row_count(array: &ArrayRef) -> bool {
103 if array.is::<ExactScalarFn<RowCount>>() {
104 return true;
105 }
106 match array.as_opt::<ScalarFn>() {
107 Some(view) => view.iter_children().any(contains_row_count),
108 None => false,
109 }
110}
111
112pub fn substitute_row_count(array: ArrayRef, replacement: &ArrayRef) -> VortexResult<ArrayRef> {
121 if array.is::<ExactScalarFn<RowCount>>() {
122 vortex_ensure!(
123 replacement.len() == array.len(),
124 "RowCount replacement length {} does not match scope length {}",
125 replacement.len(),
126 array.len(),
127 );
128 vortex_ensure!(
129 replacement.dtype() == array.dtype(),
130 "RowCount replacement dtype {} does not match scope dtype {}",
131 replacement.dtype(),
132 array.dtype(),
133 );
134 return Ok(replacement.clone());
135 }
136
137 if !array.is::<ScalarFn>() {
138 return Ok(array);
139 }
140
141 let nchildren = array.nchildren();
142 let mut array = array;
143 for slot_idx in 0..nchildren {
144 let (taken, child) = unsafe { array.take_slot_unchecked(slot_idx)? };
149 let new_child = substitute_row_count(child, replacement)?;
150 array = unsafe { taken.put_slot_unchecked(slot_idx, new_child)? };
151 }
152 Ok(array)
153}
154
#[cfg(test)]
mod tests {
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::scalar_fn::EmptyOptions;
    use crate::scalar_fn::internal::row_count::RowCount;
    use crate::scalar_fn::vtable::ScalarFnVTableExt;

    /// A `RowCount` expression resolves to a non-nullable `u64` even when the scope dtype
    /// is a nullable `i32`.
    #[test]
    fn row_count_helper_dtype() {
        let scope = DType::Primitive(PType::I32, Nullability::Nullable);
        let expected = DType::Primitive(PType::U64, Nullability::NonNullable);

        let actual = RowCount
            .new_expr(EmptyOptions, [])
            .return_dtype(&scope)
            .unwrap();

        assert_eq!(actual, expected);
    }
}