Skip to main content

uni_plugin_rhai/
adapter_aggregate.rs

1//! Aggregate function adapter — Rhai-side aggregate via four named fns.
2//!
3//! Rhai aggregates are declared in the manifest as a single `name`. The
4//! script must export four functions following the naming convention:
5//!
6//! - `${name}_init()` — returns the initial state (typically a map).
7//! - `${name}_accumulate(state, x)` — returns the updated state.
8//! - `${name}_merge(state_a, state_b)` — returns the merged state.
9//! - `${name}_finalize(state)` — returns the final aggregate value.
10//!
11//! This four-callable shape avoids the complexity of invoking Rhai
12//! closures stored inside `const` maps; the trade-off is that authors
13//! cannot inline-define an aggregate, but the wiring is straightforward
14//! and survives the script's parse-time check.
15
16#![cfg(feature = "rhai-runtime")]
17
18use std::sync::Arc;
19
20use arrow_array::{ArrayRef, BinaryArray, LargeBinaryArray};
21use arrow_schema::{DataType, Field};
22use datafusion::scalar::ScalarValue;
23use rhai::{Dynamic, Scope};
24use smol_str::SmolStr;
25
26use uni_plugin::errors::FnError;
27use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
28use uni_plugin::traits::scalar::ArgType;
29
30use crate::dynamic_bridge::array_row_to_dynamic;
31use crate::runtime::RhaiPluginRuntime;
32
33/// Aggregate fn adapter — implements `AggregatePluginFn` by dispatching
34/// to four Rhai callables.
35#[derive(Debug)]
36pub struct RhaiAggregateFn {
37    runtime: Arc<RhaiPluginRuntime>,
38    name: SmolStr,
39    signature: AggSignature,
40}
41
42impl RhaiAggregateFn {
43    /// Construct an aggregate adapter for `name`. The Rhai script must
44    /// export `${name}_init`, `${name}_accumulate`, `${name}_merge`,
45    /// `${name}_finalize`.
46    #[must_use]
47    pub fn new(
48        runtime: Arc<RhaiPluginRuntime>,
49        name: impl Into<SmolStr>,
50        signature: AggSignature,
51    ) -> Self {
52        Self {
53            runtime,
54            name: name.into(),
55            signature,
56        }
57    }
58}
59
60impl AggregatePluginFn for RhaiAggregateFn {
61    fn signature(&self) -> &AggSignature {
62        &self.signature
63    }
64
65    fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
66        // Initialise state from `${name}_init()`. The previous form used
67        // `.unwrap_or(Dynamic::UNIT)`, which silently substituted UNIT
68        // for any init failure (missing function, panic, type error) and
69        // then corrupted every downstream call. We now capture the init
70        // error and surface it on the first call to any trait method.
71        let mut scope = Scope::new();
72        let init_fn = format!("{}_init", self.name);
73        let (state, init_error) = match self.runtime.engine.call_fn::<Dynamic>(
74            &mut scope,
75            &self.runtime.ast,
76            &init_fn,
77            (),
78        ) {
79            Ok(s) => (s, None),
80            Err(e) => (
81                Dynamic::UNIT,
82                Some(FnError::new(
83                    0x723,
84                    format!("Rhai aggregate `{}` init failed: {e}", self.name),
85                )),
86            ),
87        };
88        Box::new(RhaiAccumulator {
89            runtime: Arc::clone(&self.runtime),
90            name: self.name.clone(),
91            state,
92            input_types: self.signature.args.clone(),
93            init_error,
94        })
95    }
96}
97
98/// Per-group accumulator backed by a `rhai::Dynamic` state value.
99pub struct RhaiAccumulator {
100    runtime: Arc<RhaiPluginRuntime>,
101    name: SmolStr,
102    state: Dynamic,
103    input_types: Vec<ArgType>,
104    /// Set when `${name}_init` failed at construction. Every trait
105    /// method short-circuits with this error so the accumulator can't
106    /// silently produce garbage state.
107    init_error: Option<FnError>,
108}
109
110impl std::fmt::Debug for RhaiAccumulator {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        f.debug_struct("RhaiAccumulator")
113            .field("name", &self.name)
114            .finish_non_exhaustive()
115    }
116}
117
118impl RhaiAccumulator {
119    /// Surface a cached init failure to any trait method. Cloning the
120    /// `FnError` lets us keep the original so subsequent calls also
121    /// fail (rather than succeeding on the second call once the error
122    /// is taken).
123    fn check_init(&self) -> Result<(), FnError> {
124        match &self.init_error {
125            Some(e) => Err(e.clone()),
126            None => Ok(()),
127        }
128    }
129}
130
131impl PluginAccumulator for RhaiAccumulator {
132    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
133        self.check_init()?;
134        let accumulate_fn = format!("{}_accumulate", self.name);
135        let n = values.first().map(|a| a.len()).unwrap_or(0);
136
137        for row in 0..n {
138            let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(values.len() + 1);
139            dyn_args.push(self.state.clone());
140            for (i, arr) in values.iter().enumerate() {
141                let dt = primitive_datatype(&self.input_types, i)?;
142                let d = array_row_to_dynamic(arr, row, &dt)
143                    .map_err(|e| FnError::new(0x12, e.to_string()))?;
144                dyn_args.push(d);
145            }
146            let mut scope = Scope::new();
147            let new_state = self
148                .runtime
149                .engine
150                .call_fn::<Dynamic>(&mut scope, &self.runtime.ast, &accumulate_fn, dyn_args)
151                .map_err(|e| FnError::new(0x720, format!("Rhai accumulate: {e}")))?;
152            self.state = new_state;
153        }
154        Ok(())
155    }
156
157    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
158        self.check_init()?;
159        let merge_fn = format!("{}_merge", self.name);
160        let Some(state_arr) = states.first() else {
161            return Ok(());
162        };
163        let n = state_arr.len();
164
165        for row in 0..n {
166            // Decode peer state bytes from a Binary/LargeBinary column.
167            let bytes = peer_state_bytes(state_arr, row)?;
168            let peer_state = decode_state(&bytes)?;
169            let mut scope = Scope::new();
170            let new_state = self
171                .runtime
172                .engine
173                .call_fn::<Dynamic>(
174                    &mut scope,
175                    &self.runtime.ast,
176                    &merge_fn,
177                    (self.state.clone(), peer_state),
178                )
179                .map_err(|e| FnError::new(0x721, format!("Rhai merge: {e}")))?;
180            self.state = new_state;
181        }
182        Ok(())
183    }
184
185    fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
186        self.check_init()?;
187        let bytes = encode_state(&self.state)?;
188        Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
189    }
190
191    fn evaluate(&self) -> Result<ScalarValue, FnError> {
192        self.check_init()?;
193        let finalize_fn = format!("{}_finalize", self.name);
194        let mut scope = Scope::new();
195        let result = self
196            .runtime
197            .engine
198            .call_fn::<Dynamic>(
199                &mut scope,
200                &self.runtime.ast,
201                &finalize_fn,
202                (self.state.clone(),),
203            )
204            .map_err(|e| FnError::new(0x722, format!("Rhai finalize: {e}")))?;
205        dynamic_to_scalar_loose(result)
206    }
207
208    fn size(&self) -> usize {
209        // Conservative estimate. Dynamic doesn't expose memory_use().
210        std::mem::size_of::<Self>() + 64
211    }
212}
213
214fn primitive_datatype(args: &[ArgType], i: usize) -> Result<DataType, FnError> {
215    match args.get(i) {
216        Some(ArgType::Primitive(dt)) => Ok(dt.clone()),
217        Some(other) => Err(FnError::new(
218            0x10,
219            format!("Rhai aggregate arg {i}: primitives only, got {other:?}"),
220        )),
221        None => Err(FnError::new(0x10, format!("missing arg type {i}"))),
222    }
223}
224
225fn peer_state_bytes(arr: &ArrayRef, row: usize) -> Result<Vec<u8>, FnError> {
226    if arr.is_null(row) {
227        return Ok(Vec::new());
228    }
229    if let Some(a) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
230        return Ok(a.value(row).to_vec());
231    }
232    if let Some(a) = arr.as_any().downcast_ref::<BinaryArray>() {
233        return Ok(a.value(row).to_vec());
234    }
235    Err(FnError::new(
236        0x12,
237        format!(
238            "Rhai aggregate merge: expected Binary/LargeBinary state column, got {:?}",
239            arr.data_type()
240        ),
241    ))
242}
243
244fn encode_state(state: &Dynamic) -> Result<Vec<u8>, FnError> {
245    serde_json::to_vec(state).map_err(|e| FnError::new(0x13, format!("Rhai state encode: {e}")))
246}
247
248fn decode_state(bytes: &[u8]) -> Result<Dynamic, FnError> {
249    if bytes.is_empty() {
250        return Ok(Dynamic::UNIT);
251    }
252    let v: serde_json::Value = serde_json::from_slice(bytes)
253        .map_err(|e| FnError::new(0x13, format!("Rhai state decode: {e}")))?;
254    serde_json_to_dynamic(&v).map_err(|e| FnError::new(0x13, format!("Rhai state value: {e}")))
255}
256
257/// Convert a `serde_json::Value` into a `rhai::Dynamic`. Used for
258/// rehydrating peer states during merge.
259pub fn serde_json_to_dynamic(v: &serde_json::Value) -> Result<Dynamic, String> {
260    use serde_json::Value as J;
261    Ok(match v {
262        J::Null => Dynamic::UNIT,
263        J::Bool(b) => Dynamic::from(*b),
264        J::Number(n) => {
265            if let Some(i) = n.as_i64() {
266                Dynamic::from(i)
267            } else if let Some(f) = n.as_f64() {
268                Dynamic::from(f)
269            } else {
270                return Err(format!("unrepresentable number: {n}"));
271            }
272        }
273        J::String(s) => Dynamic::from(s.clone()),
274        J::Array(arr) => {
275            let mut out: rhai::Array = Vec::with_capacity(arr.len());
276            for item in arr {
277                out.push(serde_json_to_dynamic(item)?);
278            }
279            Dynamic::from(out)
280        }
281        J::Object(obj) => {
282            let mut out: rhai::Map = rhai::Map::new();
283            for (k, v) in obj {
284                out.insert(k.as_str().into(), serde_json_to_dynamic(v)?);
285            }
286            Dynamic::from(out)
287        }
288    })
289}
290
291fn dynamic_to_scalar_loose(d: Dynamic) -> Result<ScalarValue, FnError> {
292    if d.is_unit() {
293        return Ok(ScalarValue::Null);
294    }
295    if let Ok(b) = d.as_bool() {
296        return Ok(ScalarValue::Boolean(Some(b)));
297    }
298    if let Ok(i) = d.as_int() {
299        return Ok(ScalarValue::Int64(Some(i)));
300    }
301    if let Ok(f) = d.as_float() {
302        return Ok(ScalarValue::Float64(Some(f)));
303    }
304    if let Ok(s) = d.clone().into_string() {
305        return Ok(ScalarValue::Utf8(Some(s)));
306    }
307    // Fallback: encode as JSON LargeUtf8 for unsupported composite types.
308    let bytes = serde_json::to_string(&d).map_err(|e| FnError::new(0x13, e.to_string()))?;
309    Ok(ScalarValue::LargeUtf8(Some(bytes)))
310}
311
312/// Build the standard state-field schema for a Rhai aggregate. v1
313/// always serializes the Dynamic state as a single LargeBinary column.
314#[must_use]
315pub fn rhai_state_fields() -> Vec<Field> {
316    vec![Field::new("rhai_state", DataType::LargeBinary, true)]
317}
318
319/// Helper to build an `AggSignature` for a Rhai aggregate from wire
320/// strings.
321pub fn build_agg_signature(
322    args: &[String],
323    returns: &str,
324    determinism: &str,
325) -> Result<AggSignature, crate::error::RhaiError> {
326    use crate::wire_translate::{determinism_to_volatility, type_name_to_argtype};
327    let arg_types: Vec<ArgType> = args
328        .iter()
329        .map(|s| type_name_to_argtype(s))
330        .collect::<Result<_, _>>()?;
331    // Return type for aggregates: aggregates often return a map; fall
332    // back to LargeUtf8 when the wire-name maps to nothing we can encode
333    // as a primitive (e.g. "map").
334    let return_type = match returns.trim().to_ascii_lowercase().as_str() {
335        "map" | "object" | "any" => ArgType::Primitive(DataType::LargeUtf8),
336        _ => type_name_to_argtype(returns)?,
337    };
338    Ok(AggSignature {
339        args: arg_types,
340        returns: return_type,
341        state_fields: rhai_state_fields(),
342        volatility: determinism_to_volatility(determinism),
343        supports_partial: true,
344    })
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_array::Float64Array;
    use datafusion::logical_expr::Volatility;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Mean-style aggregate whose state is a Rhai map. Shared by the tests
    /// that exercise map (object) states.
    const STATS_SCRIPT: &str = r#"
        fn stats_init() {
            #{ n: 0, sum: 0.0, sum_sq: 0.0 }
        }
        fn stats_accumulate(state, x) {
            state.n += 1;
            state.sum += x;
            state.sum_sq += x * x;
            state
        }
        fn stats_merge(a, b) {
            #{ n: a.n + b.n, sum: a.sum + b.sum, sum_sq: a.sum_sq + b.sum_sq }
        }
        fn stats_finalize(s) {
            if s.n == 0 { return (); }
            s.sum / s.n
        }
    "#;

    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let engine = build_engine(&CapabilitySet::new(), &RhaiHostFnRegistry::new());
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.agg"), engine, ast)
    }

    /// Signature for one `Float64` argument returning `Float64`.
    fn float_signature() -> AggSignature {
        AggSignature {
            args: vec![ArgType::Primitive(DataType::Float64)],
            returns: ArgType::Primitive(DataType::Float64),
            state_fields: rhai_state_fields(),
            volatility: Volatility::Immutable,
            supports_partial: true,
        }
    }

    fn float_column(xs: &[f64]) -> ArrayRef {
        Arc::new(Float64Array::from(xs.to_vec()))
    }

    /// Extract the serialized state bytes an accumulator would ship to a peer.
    fn state_bytes(acc: &dyn PluginAccumulator) -> Vec<u8> {
        let state = acc.state().unwrap();
        match &state[0] {
            ScalarValue::LargeBinary(Some(b)) => b.clone(),
            other => panic!("expected LargeBinary, got {other:?}"),
        }
    }

    fn assert_float(result: ScalarValue, expected: f64) {
        match result {
            ScalarValue::Float64(Some(v)) => assert!((v - expected).abs() < 1e-9),
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn stats_aggregate_round_trips() {
        let agg = RhaiAggregateFn::new(build_runtime(STATS_SCRIPT), "stats", float_signature());
        let mut acc = agg.create_accumulator();
        acc.update_batch(&[float_column(&[1.0, 2.0, 3.0, 4.0])]).unwrap();
        assert_float(acc.evaluate().unwrap(), 2.5);
    }

    #[test]
    fn state_serializes_and_merges() {
        let script = r#"
            fn sum_init() { 0.0 }
            fn sum_accumulate(state, x) { state + x }
            fn sum_merge(a, b) { a + b }
            fn sum_finalize(s) { s }
        "#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "sum", float_signature());

        // First partition accumulates [1,2,3]; serializes its state.
        let mut a = agg.create_accumulator();
        a.update_batch(&[float_column(&[1.0, 2.0, 3.0])]).unwrap();
        let bytes = state_bytes(a.as_ref());

        // Second partition accumulates [10,20]; merges first's state.
        let mut b = agg.create_accumulator();
        b.update_batch(&[float_column(&[10.0, 20.0])]).unwrap();
        let peer_arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![bytes.as_slice()]));
        b.merge_batch(&[peer_arr]).unwrap();
        assert_float(b.evaluate().unwrap(), 36.0);
    }

    #[test]
    fn map_state_survives_serialization_and_merge() {
        let agg = RhaiAggregateFn::new(build_runtime(STATS_SCRIPT), "stats", float_signature());

        let mut a = agg.create_accumulator();
        a.update_batch(&[float_column(&[1.0, 2.0])]).unwrap();
        let bytes = state_bytes(a.as_ref());

        let mut b = agg.create_accumulator();
        b.update_batch(&[float_column(&[3.0, 4.0])]).unwrap();
        let peer_arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![bytes.as_slice()]));
        b.merge_batch(&[peer_arr]).unwrap();

        // n = 4, sum = 10.0 after the merge.
        assert_float(b.evaluate().unwrap(), 2.5);
    }

    #[test]
    fn init_failure_is_surfaced_on_every_call() {
        // The script deliberately lacks `broken_init`.
        let script = r#"
            fn broken_accumulate(state, x) { state + x }
            fn broken_merge(a, b) { a + b }
            fn broken_finalize(s) { s }
        "#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "broken", float_signature());
        let mut acc = agg.create_accumulator();

        assert!(acc.update_batch(&[float_column(&[1.0])]).is_err());
        // The cached error is cloned, not taken, so it must keep firing.
        assert!(acc.update_batch(&[float_column(&[1.0])]).is_err());
        assert!(acc.merge_batch(&[]).is_err());
        assert!(acc.state().is_err());
        assert!(acc.evaluate().is_err());
    }

    #[test]
    fn serde_json_to_dynamic_converts_nested_values() {
        let json: serde_json::Value =
            serde_json::from_str(r#"{"n": 2, "xs": [1.5, true, null, "a"]}"#).unwrap();
        let map = serde_json_to_dynamic(&json).unwrap().cast::<rhai::Map>();
        assert_eq!(map["n"].as_int().unwrap(), 2);

        let xs = map["xs"].clone().cast::<rhai::Array>();
        assert_eq!(xs.len(), 4);
        assert!((xs[0].as_float().unwrap() - 1.5).abs() < 1e-12);
        assert!(xs[1].as_bool().unwrap());
        assert!(xs[2].is_unit());
        assert_eq!(xs[3].clone().into_string().unwrap(), "a");
    }

    #[test]
    fn loose_scalar_conversion_of_primitives() {
        assert_eq!(dynamic_to_scalar_loose(Dynamic::UNIT).unwrap(), ScalarValue::Null);
        assert_eq!(
            dynamic_to_scalar_loose(Dynamic::from(true)).unwrap(),
            ScalarValue::Boolean(Some(true))
        );
        assert_eq!(
            dynamic_to_scalar_loose(Dynamic::from(7_i64)).unwrap(),
            ScalarValue::Int64(Some(7))
        );
        assert_eq!(
            dynamic_to_scalar_loose(Dynamic::from("hi".to_string())).unwrap(),
            ScalarValue::Utf8(Some("hi".to_string()))
        );
    }
}