uni_query/query/df_graph/reduce.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

use std::any::Any;
use std::fmt::{self, Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;

use datafusion::arrow::array::{Array, RecordBatch};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::PhysicalExpr;

/// Physical expression for Cypher REDUCE: `reduce(acc = init, x IN list | expr)`
///
/// Executes reduction by iterating layer-by-layer (vectorized over list index).
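///
/// For example, `reduce(total = 0, x IN [1, 2, 3] | total + x)` evaluates to `6`.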
#[derive(Debug, Clone)]
pub struct ReduceExecExpr {
    /// Name of the accumulator variable
    accumulator_name: String,
    /// Expression for initial value
    initial_expr: Arc<dyn PhysicalExpr>,
    /// Name of the loop variable
    variable_name: String,
    /// Expression producing the list
    list_expr: Arc<dyn PhysicalExpr>,
    /// Reduction expression (update logic)
    reduce_expr: Arc<dyn PhysicalExpr>,
    /// Schema of the input batch
    input_schema: Arc<Schema>,
    /// Output data type (type of reduce_expr)
    output_type: DataType,
}

impl ReduceExecExpr {
    pub fn new(
        accumulator_name: String,
        initial_expr: Arc<dyn PhysicalExpr>,
        variable_name: String,
        list_expr: Arc<dyn PhysicalExpr>,
        reduce_expr: Arc<dyn PhysicalExpr>,
        input_schema: Arc<Schema>,
        output_type: DataType,
    ) -> Self {
        Self {
            accumulator_name,
            initial_expr,
            variable_name,
            list_expr,
            reduce_expr,
            input_schema,
            output_type,
        }
    }
}

impl Display for ReduceExecExpr {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(
            f,
            "reduce({} = {}, {} IN {} | {})",
            self.accumulator_name,
            self.initial_expr,
            self.variable_name,
            self.list_expr,
            self.reduce_expr
        )
    }
}

impl PartialEq for ReduceExecExpr {
    fn eq(&self, other: &Self) -> bool {
        self.accumulator_name == other.accumulator_name
            && self.variable_name == other.variable_name
            && self.output_type == other.output_type
            && Arc::ptr_eq(&self.initial_expr, &other.initial_expr)
            && Arc::ptr_eq(&self.list_expr, &other.list_expr)
            && Arc::ptr_eq(&self.reduce_expr, &other.reduce_expr)
    }
}

impl Eq for ReduceExecExpr {}

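// Hash only the fields that PartialEq compares structurally; the Arc-wrapped child
// expressions are compared by pointer identity in `eq` and are intentionally omitted
// here, keeping the Hash implementation consistent with Eq.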
impl Hash for ReduceExecExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.accumulator_name.hash(state);
        self.variable_name.hash(state);
        self.output_type.hash(state);
    }
}

impl PartialEq<dyn Any> for ReduceExecExpr {
    fn eq(&self, other: &dyn Any) -> bool {
        other
            .downcast_ref::<Self>()
            .map(|x| self == x)
            .unwrap_or(false)
    }
}

impl PhysicalExpr for ReduceExecExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.output_type.clone())
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        // 1. Evaluate input list
        let list_val = self.list_expr.evaluate(batch)?;
        let list_array = list_val.into_array(batch.num_rows())?;

        // Decode CypherValue-encoded arrays (LargeBinary → LargeList<element_type>)
        // Use the accumulator type as the target element type since the reduce body
        // was compiled expecting elements to match the accumulator type.
        let list_array = if let DataType::LargeBinary = list_array.data_type() {
            let element_type = self.output_type.clone();
            crate::query::df_graph::common::cv_array_to_large_list(
                list_array.as_ref(),
                &element_type,
            )?
        } else {
            list_array
        };

        // Normalize to LargeList
        let list_array = if let DataType::List(field) = list_array.data_type() {
            let target_type = DataType::LargeList(field.clone());
            cast(&list_array, &target_type).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
            })?
        } else {
            list_array
        };

        let large_list = list_array
            .as_any()
            .downcast_ref::<datafusion::arrow::array::LargeListArray>()
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution("Expected LargeListArray".to_string())
            })?;

        let offsets = large_list.offsets();
        let values = large_list.values();

        // 2. Evaluate initial value -> current accumulator
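        // into_array broadcasts a scalar initial value so each input row gets its own
        // accumulator slot.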
        let init_val = self.initial_expr.evaluate(batch)?;
        let mut current_acc = init_val.into_array(batch.num_rows())?;

        // 3. Layer-by-layer evaluation
        // Find max length
        let mut max_len = 0;
        for window in offsets.windows(2) {
            let len = (window[1] - window[0]) as usize;
            if len > max_len {
                max_len = len;
            }
        }

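        // Iteration i applies the reduce expression to element i of every list that
        // still has one, so the body is evaluated at most max_len times in total,
        // each time vectorized over all rows that remain active.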
        for i in 0..max_len {
            // Identify active rows (list len > i)
            let mut active_indices_builder =
                datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());
            let mut variable_indices_builder =
                datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());

            for (row_idx, window) in offsets.windows(2).enumerate() {
                let start = window[0] as usize;
                let end = window[1] as usize;
                let len = end - start;
                if i < len {
                    active_indices_builder.append_value(row_idx as u32);
                    variable_indices_builder.append_value((start + i) as u32);
                }
            }
            let active_indices = active_indices_builder.finish();
            let variable_indices = variable_indices_builder.finish();

            if active_indices.is_empty() {
                break;
            }

            // Construct inner batch for active rows
            // 1. Take outer columns using active_indices
            let mut inner_columns = Vec::with_capacity(batch.num_columns() + 2);
            for col in batch.columns() {
                let taken = datafusion::arrow::compute::take(col, &active_indices, None)?;
                inner_columns.push(taken);
            }

            // Construct inner schema with accumulator and variable fields
            let mut inner_fields = batch.schema().fields().to_vec();
            let acc_field = Arc::new(Field::new(
                &self.accumulator_name,
                current_acc.data_type().clone(),
                true,
            ));
            let var_field = Arc::new(Field::new(
                &self.variable_name,
                values.data_type().clone(),
                true,
            ));

            // 2. Take accumulator values and replace/append to columns
            let acc_taken = datafusion::arrow::compute::take(&current_acc, &active_indices, None)?;
            if let Some(pos) = inner_fields
                .iter()
                .position(|f| f.name() == &self.accumulator_name)
            {
                inner_columns[pos] = acc_taken;
                inner_fields[pos] = acc_field;
            } else {
                inner_columns.push(acc_taken);
                inner_fields.push(acc_field);
            }

            // 3. Take variable values from flattened list values and replace/append to columns
            let var_taken = datafusion::arrow::compute::take(values, &variable_indices, None)?;
            if let Some(pos) = inner_fields
                .iter()
                .position(|f| f.name() == &self.variable_name)
            {
                inner_columns[pos] = var_taken;
                inner_fields[pos] = var_field;
            } else {
                inner_columns.push(var_taken);
                inner_fields.push(var_field);
            }

            let inner_schema = Arc::new(Schema::new(inner_fields));

            let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;

            // Evaluate reduce expr
            let new_acc_val = self.reduce_expr.evaluate(&inner_batch)?;
            let new_acc_array = new_acc_val.into_array(inner_batch.num_rows())?;

            // Scatter updates back to current_acc

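            // Fast path: when every row is still active, the new accumulator simply
            // replaces the old one. Otherwise, interleave keeps the previous value for
            // rows whose list is already exhausted and takes the freshly computed value
            // for the active rows.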
            if active_indices.len() == batch.num_rows() {
                current_acc = new_acc_array;
            } else {
                let mut interleave_indices = Vec::with_capacity(batch.num_rows());
                let mut active_map = vec![None; batch.num_rows()];
                for (k, &row_idx) in active_indices.values().iter().enumerate() {
                    active_map[row_idx as usize] = Some(k);
                }

                for (row_idx, slot) in active_map.iter().enumerate() {
                    if let Some(k) = slot {
                        interleave_indices.push((1, *k)); // 1 = new_acc_array
                    } else {
                        interleave_indices.push((0, row_idx)); // 0 = current_acc
                    }
                }

                current_acc = datafusion::arrow::compute::interleave(
                    &[&current_acc, &new_acc_array],
                    &interleave_indices,
                )?;
            }
        }

        Ok(ColumnarValue::Array(current_acc))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        // Only expose expressions compiled against the outer schema.
        // reduce_expr is compiled against an inner schema (with loop variable and accumulator)
        // and should not be exposed to DataFusion's expression tree traversal.
        vec![&self.initial_expr, &self.list_expr]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        if children.len() != 2 {
            return Err(datafusion::error::DataFusionError::Internal(
                "Reduce requires 2 children (initial_expr, list_expr)".to_string(),
            ));
        }
        Ok(Arc::new(Self {
            initial_expr: children[0].clone(),
            list_expr: children[1].clone(),
            reduce_expr: self.reduce_expr.clone(),
            accumulator_name: self.accumulator_name.clone(),
            variable_name: self.variable_name.clone(),
            input_schema: self.input_schema.clone(),
            output_type: self.output_type.clone(),
        }))
    }

    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "reduce({} = {}, {} IN {} | {})",
            self.accumulator_name,
            self.initial_expr,
            self.variable_name,
            self.list_expr,
            self.reduce_expr
        )
    }
}