// uni_query/query/df_graph/reduce.rs
use std::any::Any;
5use std::fmt::{self, Display, Formatter};
6use std::hash::Hash;
7use std::sync::Arc;
8
9use datafusion::arrow::array::{Array, RecordBatch};
10use datafusion::arrow::compute::cast;
11use datafusion::arrow::datatypes::{DataType, Field, Schema};
12use datafusion::common::Result;
13use datafusion::logical_expr::ColumnarValue;
14use datafusion::physical_plan::PhysicalExpr;
15
/// Physical expression for a Cypher-style list fold:
/// `reduce(acc = init, x IN list | expr)` (see the `Display` impl for the
/// rendered form).
#[derive(Debug, Clone)]
pub struct ReduceExecExpr {
    /// Name the fold body uses to refer to the running accumulator.
    accumulator_name: String,
    /// Expression producing the initial accumulator value (one per row).
    initial_expr: Arc<dyn PhysicalExpr>,
    /// Name the fold body uses to refer to the current list element.
    variable_name: String,
    /// Expression producing the list being folded over.
    list_expr: Arc<dyn PhysicalExpr>,
    /// Fold body; evaluated against an inner batch extended with the
    /// accumulator and element columns (see `evaluate`).
    reduce_expr: Arc<dyn PhysicalExpr>,
    /// Schema of the input batches this expression is evaluated against.
    input_schema: Arc<Schema>,
    /// Result type of the reduction, fixed at planning time.
    output_type: DataType,
}
36
37impl ReduceExecExpr {
38 pub fn new(
39 accumulator_name: String,
40 initial_expr: Arc<dyn PhysicalExpr>,
41 variable_name: String,
42 list_expr: Arc<dyn PhysicalExpr>,
43 reduce_expr: Arc<dyn PhysicalExpr>,
44 input_schema: Arc<Schema>,
45 output_type: DataType,
46 ) -> Self {
47 Self {
48 accumulator_name,
49 initial_expr,
50 variable_name,
51 list_expr,
52 reduce_expr,
53 input_schema,
54 output_type,
55 }
56 }
57}
58
59impl Display for ReduceExecExpr {
60 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
61 write!(
62 f,
63 "reduce({} = {}, {} IN {} | {})",
64 self.accumulator_name,
65 self.initial_expr,
66 self.variable_name,
67 self.list_expr,
68 self.reduce_expr
69 )
70 }
71}
72
impl PartialEq for ReduceExecExpr {
    // Child expressions are compared by pointer identity (`Arc::ptr_eq`),
    // not structurally: two semantically identical reduce expressions built
    // independently will compare unequal. This is a conservative choice;
    // it only yields false negatives, never false positives.
    fn eq(&self, other: &Self) -> bool {
        self.accumulator_name == other.accumulator_name
            && self.variable_name == other.variable_name
            && self.output_type == other.output_type
            && Arc::ptr_eq(&self.initial_expr, &other.initial_expr)
            && Arc::ptr_eq(&self.list_expr, &other.list_expr)
            && Arc::ptr_eq(&self.reduce_expr, &other.reduce_expr)
    }
}

impl Eq for ReduceExecExpr {}
85
impl Hash for ReduceExecExpr {
    // Hashes only the name/type fields and skips the child expressions.
    // This is consistent with `PartialEq` (values equal under `eq` hash
    // equally) — hashing a subset of the compared fields is always sound,
    // it just coarsens bucket distribution.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.accumulator_name.hash(state);
        self.variable_name.hash(state);
        self.output_type.hash(state);
    }
}
93
94impl PartialEq<dyn Any> for ReduceExecExpr {
95 fn eq(&self, other: &dyn Any) -> bool {
96 other
97 .downcast_ref::<Self>()
98 .map(|x| self == x)
99 .unwrap_or(false)
100 }
101}
102
103impl PhysicalExpr for ReduceExecExpr {
104 fn as_any(&self) -> &dyn Any {
105 self
106 }
107
108 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
109 Ok(self.output_type.clone())
110 }
111
112 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
113 Ok(true)
114 }
115
116 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
117 let list_val = self.list_expr.evaluate(batch)?;
119 let list_array = list_val.into_array(batch.num_rows())?;
120
121 let list_array = if let DataType::LargeBinary = list_array.data_type() {
125 let element_type = self.output_type.clone();
126 crate::query::df_graph::common::cv_array_to_large_list(
127 list_array.as_ref(),
128 &element_type,
129 )?
130 } else {
131 list_array
132 };
133
134 let list_array = if let DataType::List(field) = list_array.data_type() {
136 let target_type = DataType::LargeList(field.clone());
137 cast(&list_array, &target_type).map_err(|e| {
138 datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
139 })?
140 } else {
141 list_array
142 };
143
144 let large_list = list_array
145 .as_any()
146 .downcast_ref::<datafusion::arrow::array::LargeListArray>()
147 .ok_or_else(|| {
148 datafusion::error::DataFusionError::Execution("Expected LargeListArray".to_string())
149 })?;
150
151 let offsets = large_list.offsets();
152 let values = large_list.values();
153
154 let init_val = self.initial_expr.evaluate(batch)?;
156 let mut current_acc = init_val.into_array(batch.num_rows())?;
157
158 let mut max_len = 0;
161 for window in offsets.windows(2) {
162 let len = (window[1] - window[0]) as usize;
163 if len > max_len {
164 max_len = len;
165 }
166 }
167
168 for i in 0..max_len {
169 let mut active_indices_builder =
171 datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());
172 let mut variable_indices_builder =
173 datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());
174
175 for (row_idx, window) in offsets.windows(2).enumerate() {
176 let start = window[0] as usize;
177 let end = window[1] as usize;
178 let len = end - start;
179 if i < len {
180 active_indices_builder.append_value(row_idx as u32);
181 variable_indices_builder.append_value((start + i) as u32);
182 }
183 }
184 let active_indices = active_indices_builder.finish();
185 let variable_indices = variable_indices_builder.finish();
186
187 if active_indices.is_empty() {
188 break;
189 }
190
191 let mut inner_columns = Vec::with_capacity(batch.num_columns() + 2);
194 for col in batch.columns() {
195 let taken = datafusion::arrow::compute::take(col, &active_indices, None)?;
196 inner_columns.push(taken);
197 }
198
199 let mut inner_fields = batch.schema().fields().to_vec();
201 let acc_field = Arc::new(Field::new(
202 &self.accumulator_name,
203 current_acc.data_type().clone(),
204 true,
205 ));
206 let var_field = Arc::new(Field::new(
207 &self.variable_name,
208 values.data_type().clone(),
209 true,
210 ));
211
212 let acc_taken = datafusion::arrow::compute::take(¤t_acc, &active_indices, None)?;
214 if let Some(pos) = inner_fields
215 .iter()
216 .position(|f| f.name() == &self.accumulator_name)
217 {
218 inner_columns[pos] = acc_taken;
219 inner_fields[pos] = acc_field;
220 } else {
221 inner_columns.push(acc_taken);
222 inner_fields.push(acc_field);
223 }
224
225 let var_taken = datafusion::arrow::compute::take(values, &variable_indices, None)?;
227 if let Some(pos) = inner_fields
228 .iter()
229 .position(|f| f.name() == &self.variable_name)
230 {
231 inner_columns[pos] = var_taken;
232 inner_fields[pos] = var_field;
233 } else {
234 inner_columns.push(var_taken);
235 inner_fields.push(var_field);
236 }
237
238 let inner_schema = Arc::new(Schema::new(inner_fields));
239
240 let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;
241
242 let new_acc_val = self.reduce_expr.evaluate(&inner_batch)?;
244 let new_acc_array = new_acc_val.into_array(inner_batch.num_rows())?;
245
246 if active_indices.len() == batch.num_rows() {
249 current_acc = new_acc_array;
250 } else {
251 let mut interleave_indices = Vec::with_capacity(batch.num_rows());
252 let mut active_map = vec![None; batch.num_rows()];
253 for (k, &row_idx) in active_indices.values().iter().enumerate() {
254 active_map[row_idx as usize] = Some(k);
255 }
256
257 for (row_idx, slot) in active_map.iter().enumerate() {
258 if let Some(k) = slot {
259 interleave_indices.push((1, *k)); } else {
261 interleave_indices.push((0, row_idx)); }
263 }
264
265 current_acc = datafusion::arrow::compute::interleave(
266 &[¤t_acc, &new_acc_array],
267 &interleave_indices,
268 )?;
269 }
270 }
271
272 Ok(ColumnarValue::Array(current_acc))
273 }
274
275 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
276 vec![&self.initial_expr, &self.list_expr]
280 }
281
282 fn with_new_children(
283 self: Arc<Self>,
284 children: Vec<Arc<dyn PhysicalExpr>>,
285 ) -> Result<Arc<dyn PhysicalExpr>> {
286 if children.len() != 2 {
287 return Err(datafusion::error::DataFusionError::Internal(
288 "Reduce requires 2 children (initial_expr, list_expr)".to_string(),
289 ));
290 }
291 Ok(Arc::new(Self {
292 initial_expr: children[0].clone(),
293 list_expr: children[1].clone(),
294 reduce_expr: self.reduce_expr.clone(),
295 accumulator_name: self.accumulator_name.clone(),
296 variable_name: self.variable_name.clone(),
297 input_schema: self.input_schema.clone(),
298 output_type: self.output_type.clone(),
299 }))
300 }
301
302 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 write!(
304 f,
305 "reduce({} = {}, {} IN {} | {})",
306 self.accumulator_name,
307 self.initial_expr,
308 self.variable_name,
309 self.list_expr,
310 self.reduce_expr
311 )
312 }
313}