1#![cfg(feature = "rhai-runtime")]
17
18use std::sync::Arc;
19
20use arrow_array::{ArrayRef, BinaryArray, LargeBinaryArray};
21use arrow_schema::{DataType, Field};
22use datafusion::scalar::ScalarValue;
23use rhai::{Dynamic, Scope};
24use smol_str::SmolStr;
25
26use uni_plugin::errors::FnError;
27use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
28use uni_plugin::traits::scalar::ArgType;
29
30use crate::dynamic_bridge::array_row_to_dynamic;
31use crate::runtime::RhaiPluginRuntime;
32
33#[derive(Debug)]
36pub struct RhaiAggregateFn {
37 runtime: Arc<RhaiPluginRuntime>,
38 name: SmolStr,
39 signature: AggSignature,
40}
41
42impl RhaiAggregateFn {
43 #[must_use]
47 pub fn new(
48 runtime: Arc<RhaiPluginRuntime>,
49 name: impl Into<SmolStr>,
50 signature: AggSignature,
51 ) -> Self {
52 Self {
53 runtime,
54 name: name.into(),
55 signature,
56 }
57 }
58}
59
60impl AggregatePluginFn for RhaiAggregateFn {
61 fn signature(&self) -> &AggSignature {
62 &self.signature
63 }
64
65 fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
66 let mut scope = Scope::new();
72 let init_fn = format!("{}_init", self.name);
73 let (state, init_error) = match self.runtime.engine.call_fn::<Dynamic>(
74 &mut scope,
75 &self.runtime.ast,
76 &init_fn,
77 (),
78 ) {
79 Ok(s) => (s, None),
80 Err(e) => (
81 Dynamic::UNIT,
82 Some(FnError::new(
83 0x723,
84 format!("Rhai aggregate `{}` init failed: {e}", self.name),
85 )),
86 ),
87 };
88 Box::new(RhaiAccumulator {
89 runtime: Arc::clone(&self.runtime),
90 name: self.name.clone(),
91 state,
92 input_types: self.signature.args.clone(),
93 init_error,
94 })
95 }
96}
97
98pub struct RhaiAccumulator {
100 runtime: Arc<RhaiPluginRuntime>,
101 name: SmolStr,
102 state: Dynamic,
103 input_types: Vec<ArgType>,
104 init_error: Option<FnError>,
108}
109
110impl std::fmt::Debug for RhaiAccumulator {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("RhaiAccumulator")
113 .field("name", &self.name)
114 .finish_non_exhaustive()
115 }
116}
117
118impl RhaiAccumulator {
119 fn check_init(&self) -> Result<(), FnError> {
124 match &self.init_error {
125 Some(e) => Err(e.clone()),
126 None => Ok(()),
127 }
128 }
129}
130
131impl PluginAccumulator for RhaiAccumulator {
132 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
133 self.check_init()?;
134 let accumulate_fn = format!("{}_accumulate", self.name);
135 let n = values.first().map(|a| a.len()).unwrap_or(0);
136
137 for row in 0..n {
138 let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(values.len() + 1);
139 dyn_args.push(self.state.clone());
140 for (i, arr) in values.iter().enumerate() {
141 let dt = primitive_datatype(&self.input_types, i)?;
142 let d = array_row_to_dynamic(arr, row, &dt)
143 .map_err(|e| FnError::new(0x12, e.to_string()))?;
144 dyn_args.push(d);
145 }
146 let mut scope = Scope::new();
147 let new_state = self
148 .runtime
149 .engine
150 .call_fn::<Dynamic>(&mut scope, &self.runtime.ast, &accumulate_fn, dyn_args)
151 .map_err(|e| FnError::new(0x720, format!("Rhai accumulate: {e}")))?;
152 self.state = new_state;
153 }
154 Ok(())
155 }
156
157 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
158 self.check_init()?;
159 let merge_fn = format!("{}_merge", self.name);
160 let Some(state_arr) = states.first() else {
161 return Ok(());
162 };
163 let n = state_arr.len();
164
165 for row in 0..n {
166 let bytes = peer_state_bytes(state_arr, row)?;
168 let peer_state = decode_state(&bytes)?;
169 let mut scope = Scope::new();
170 let new_state = self
171 .runtime
172 .engine
173 .call_fn::<Dynamic>(
174 &mut scope,
175 &self.runtime.ast,
176 &merge_fn,
177 (self.state.clone(), peer_state),
178 )
179 .map_err(|e| FnError::new(0x721, format!("Rhai merge: {e}")))?;
180 self.state = new_state;
181 }
182 Ok(())
183 }
184
185 fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
186 self.check_init()?;
187 let bytes = encode_state(&self.state)?;
188 Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
189 }
190
191 fn evaluate(&self) -> Result<ScalarValue, FnError> {
192 self.check_init()?;
193 let finalize_fn = format!("{}_finalize", self.name);
194 let mut scope = Scope::new();
195 let result = self
196 .runtime
197 .engine
198 .call_fn::<Dynamic>(
199 &mut scope,
200 &self.runtime.ast,
201 &finalize_fn,
202 (self.state.clone(),),
203 )
204 .map_err(|e| FnError::new(0x722, format!("Rhai finalize: {e}")))?;
205 dynamic_to_scalar_loose(result)
206 }
207
208 fn size(&self) -> usize {
209 std::mem::size_of::<Self>() + 64
211 }
212}
213
214fn primitive_datatype(args: &[ArgType], i: usize) -> Result<DataType, FnError> {
215 match args.get(i) {
216 Some(ArgType::Primitive(dt)) => Ok(dt.clone()),
217 Some(other) => Err(FnError::new(
218 0x10,
219 format!("Rhai aggregate arg {i}: primitives only, got {other:?}"),
220 )),
221 None => Err(FnError::new(0x10, format!("missing arg type {i}"))),
222 }
223}
224
225fn peer_state_bytes(arr: &ArrayRef, row: usize) -> Result<Vec<u8>, FnError> {
226 if arr.is_null(row) {
227 return Ok(Vec::new());
228 }
229 if let Some(a) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
230 return Ok(a.value(row).to_vec());
231 }
232 if let Some(a) = arr.as_any().downcast_ref::<BinaryArray>() {
233 return Ok(a.value(row).to_vec());
234 }
235 Err(FnError::new(
236 0x12,
237 format!(
238 "Rhai aggregate merge: expected Binary/LargeBinary state column, got {:?}",
239 arr.data_type()
240 ),
241 ))
242}
243
244fn encode_state(state: &Dynamic) -> Result<Vec<u8>, FnError> {
245 serde_json::to_vec(state).map_err(|e| FnError::new(0x13, format!("Rhai state encode: {e}")))
246}
247
248fn decode_state(bytes: &[u8]) -> Result<Dynamic, FnError> {
249 if bytes.is_empty() {
250 return Ok(Dynamic::UNIT);
251 }
252 let v: serde_json::Value = serde_json::from_slice(bytes)
253 .map_err(|e| FnError::new(0x13, format!("Rhai state decode: {e}")))?;
254 serde_json_to_dynamic(&v).map_err(|e| FnError::new(0x13, format!("Rhai state value: {e}")))
255}
256
257pub fn serde_json_to_dynamic(v: &serde_json::Value) -> Result<Dynamic, String> {
260 use serde_json::Value as J;
261 Ok(match v {
262 J::Null => Dynamic::UNIT,
263 J::Bool(b) => Dynamic::from(*b),
264 J::Number(n) => {
265 if let Some(i) = n.as_i64() {
266 Dynamic::from(i)
267 } else if let Some(f) = n.as_f64() {
268 Dynamic::from(f)
269 } else {
270 return Err(format!("unrepresentable number: {n}"));
271 }
272 }
273 J::String(s) => Dynamic::from(s.clone()),
274 J::Array(arr) => {
275 let mut out: rhai::Array = Vec::with_capacity(arr.len());
276 for item in arr {
277 out.push(serde_json_to_dynamic(item)?);
278 }
279 Dynamic::from(out)
280 }
281 J::Object(obj) => {
282 let mut out: rhai::Map = rhai::Map::new();
283 for (k, v) in obj {
284 out.insert(k.as_str().into(), serde_json_to_dynamic(v)?);
285 }
286 Dynamic::from(out)
287 }
288 })
289}
290
291fn dynamic_to_scalar_loose(d: Dynamic) -> Result<ScalarValue, FnError> {
292 if d.is_unit() {
293 return Ok(ScalarValue::Null);
294 }
295 if let Ok(b) = d.as_bool() {
296 return Ok(ScalarValue::Boolean(Some(b)));
297 }
298 if let Ok(i) = d.as_int() {
299 return Ok(ScalarValue::Int64(Some(i)));
300 }
301 if let Ok(f) = d.as_float() {
302 return Ok(ScalarValue::Float64(Some(f)));
303 }
304 if let Ok(s) = d.clone().into_string() {
305 return Ok(ScalarValue::Utf8(Some(s)));
306 }
307 let bytes = serde_json::to_string(&d).map_err(|e| FnError::new(0x13, e.to_string()))?;
309 Ok(ScalarValue::LargeUtf8(Some(bytes)))
310}
311
312#[must_use]
315pub fn rhai_state_fields() -> Vec<Field> {
316 vec![Field::new("rhai_state", DataType::LargeBinary, true)]
317}
318
319pub fn build_agg_signature(
322 args: &[String],
323 returns: &str,
324 determinism: &str,
325) -> Result<AggSignature, crate::error::RhaiError> {
326 use crate::wire_translate::{determinism_to_volatility, type_name_to_argtype};
327 let arg_types: Vec<ArgType> = args
328 .iter()
329 .map(|s| type_name_to_argtype(s))
330 .collect::<Result<_, _>>()?;
331 let return_type = match returns.trim().to_ascii_lowercase().as_str() {
335 "map" | "object" | "any" => ArgType::Primitive(DataType::LargeUtf8),
336 _ => type_name_to_argtype(returns)?,
337 };
338 Ok(AggSignature {
339 args: arg_types,
340 returns: return_type,
341 state_fields: rhai_state_fields(),
342 volatility: determinism_to_volatility(determinism),
343 supports_partial: true,
344 })
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_array::Float64Array;
    use datafusion::logical_expr::Volatility;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Compiles `script` with a default engine and wraps it in a runtime.
    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let engine = build_engine(&CapabilitySet::new(), &RhaiHostFnRegistry::new());
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.agg"), engine, ast)
    }

    /// Signature shared by both tests: one `Float64` input, `Float64` output.
    fn float64_signature() -> AggSignature {
        AggSignature {
            args: vec![ArgType::Primitive(DataType::Float64)],
            returns: ArgType::Primitive(DataType::Float64),
            state_fields: rhai_state_fields(),
            volatility: Volatility::Immutable,
            supports_partial: true,
        }
    }

    /// Builds a `Float64` Arrow column from `values`.
    fn float64_column(values: &[f64]) -> ArrayRef {
        Arc::new(Float64Array::from(values.to_vec()))
    }

    /// Asserts that `result` is a non-null `Float64` within 1e-9 of `expected`.
    fn assert_float64(result: &ScalarValue, expected: f64) {
        match result {
            ScalarValue::Float64(Some(v)) => assert!((v - expected).abs() < 1e-9),
            other => panic!("unexpected result: {other:?}"),
        }
    }

    /// A map-valued state survives accumulate and finalize within one accumulator.
    #[test]
    fn stats_aggregate_round_trips() {
        let script = r#"
            fn stats_init() {
                #{ n: 0, sum: 0.0, sum_sq: 0.0 }
            }
            fn stats_accumulate(state, x) {
                state.n += 1;
                state.sum += x;
                state.sum_sq += x * x;
                state
            }
            fn stats_merge(a, b) {
                #{ n: a.n + b.n, sum: a.sum + b.sum, sum_sq: a.sum_sq + b.sum_sq }
            }
            fn stats_finalize(s) {
                if s.n == 0 { return (); }
                s.sum / s.n
            }
        "#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "stats", float64_signature());

        let mut acc = agg.create_accumulator();
        acc.update_batch(&[float64_column(&[1.0, 2.0, 3.0, 4.0])])
            .unwrap();

        assert_float64(&acc.evaluate().unwrap(), 2.5);
    }

    /// State serialized by one accumulator can be merged into another.
    #[test]
    fn state_serializes_and_merges() {
        let script = r#"
            fn sum_init() { 0.0 }
            fn sum_accumulate(state, x) { state + x }
            fn sum_merge(a, b) { a + b }
            fn sum_finalize(s) { s }
        "#;
        let agg = RhaiAggregateFn::new(build_runtime(script), "sum", float64_signature());

        // First partial: 1 + 2 + 3, exported as serialized state.
        let mut first = agg.create_accumulator();
        first
            .update_batch(&[float64_column(&[1.0, 2.0, 3.0])])
            .unwrap();
        let exported = first.state().unwrap();
        let state_bytes = match &exported[0] {
            ScalarValue::LargeBinary(Some(b)) => b.clone(),
            other => panic!("expected LargeBinary, got {other:?}"),
        };

        // Second partial: 10 + 20, then merged with the first one's state.
        let mut second = agg.create_accumulator();
        second
            .update_batch(&[float64_column(&[10.0, 20.0])])
            .unwrap();
        let peer_arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![state_bytes.as_slice()]));
        second.merge_batch(&[peer_arr]).unwrap();

        assert_float64(&second.evaluate().unwrap(), 36.0);
    }
}