1use std::sync::Arc;
33
34use rustc_hash::FxHashMap;
35use varpulis_core::{Type, Value};
36
37#[derive(Debug, Clone)]
39pub enum TypeConstraint {
40 Exact(Type),
42 Numeric,
44 Any,
46 OneOf(Vec<Type>),
48}
49
50impl TypeConstraint {
51 pub fn matches(&self, ty: &Type) -> bool {
53 match self {
54 Self::Exact(expected) => ty == expected,
55 Self::Numeric => ty.is_numeric(),
56 Self::Any => true,
57 Self::OneOf(types) => types.contains(ty),
58 }
59 }
60}
61
/// Describes a variadic tail of arguments that follows a signature's fixed
/// parameters.
#[derive(Debug, Clone)]
pub struct VariadicSpec {
    /// Minimum number of arguments the variadic tail must supply. When arity
    /// is checked, this is added to the number of fixed parameters.
    pub min_args: usize,
    /// Constraint that every argument in the variadic tail must satisfy.
    pub arg_type: TypeConstraint,
}
70
/// Type signature of a UDF, checked against call sites by
/// `UdfRegistry::validate_call`.
#[derive(Debug, Clone)]
pub struct Signature {
    /// Constraints for the fixed, positional parameters, in order.
    pub input_types: Vec<TypeConstraint>,
    /// Type the call produces; `validate_call` returns it on success.
    pub return_type: Type,
    /// Optional variadic tail after the fixed parameters. When `None`, a call
    /// must supply exactly `input_types.len()` arguments.
    pub variadic: Option<VariadicSpec>,
}
81
/// A user-defined scalar function: maps a slice of argument values to at most
/// one result value.
pub trait ScalarUDF: Send + Sync {
    /// Name under which `UdfRegistry` stores and looks up this UDF.
    fn name(&self) -> &str;
    /// Declared argument and return types. `UdfRegistry::validate_call`
    /// calls this each time it checks this UDF.
    fn signature(&self) -> Signature;
    /// Evaluates the function on `args`.
    ///
    /// NOTE(review): `None` presumably means no result could be produced for
    /// these values; confirm how callers interpret it.
    fn evaluate(&self, args: &[Value]) -> Option<Value>;
}
93
/// A user-defined aggregate function. Its running state lives in the
/// [`Accumulator`] instances returned by [`AggregateUDF::init`].
pub trait AggregateUDF: Send + Sync {
    /// Name under which `UdfRegistry` stores and looks up this UDF.
    fn name(&self) -> &str;
    /// Declared argument and return types, checked by
    /// `UdfRegistry::validate_call`.
    fn signature(&self) -> Signature;
    /// Creates a new accumulator for this aggregate.
    ///
    /// NOTE(review): presumably it starts in its empty/initial state; confirm.
    fn init(&self) -> Box<dyn Accumulator>;
}
106
/// Mutable running state of an aggregate UDF.
pub trait Accumulator: Send + Sync {
    /// Folds one input value into the running state.
    fn update(&mut self, value: &Value);
    /// Combines another accumulator's state into this one.
    ///
    /// `other` is only available as `&dyn Accumulator`, and this trait offers
    /// no downcast hook. An implementation can therefore only use this trait's
    /// methods (such as `finish`) on it.
    fn merge(&mut self, other: &dyn Accumulator);
    /// Returns the current aggregate result. Takes `&self`, so the accumulator
    /// is not consumed.
    fn finish(&self) -> Value;
    /// Clears the running state.
    ///
    /// NOTE(review): presumably this restores the state `init` produced; confirm.
    fn reset(&mut self);
    /// Clones this accumulator into a new box. Needed because a
    /// `Box<dyn Accumulator>` cannot use `Clone` directly.
    fn clone_box(&self) -> Box<dyn Accumulator>;
}
120
/// Name-indexed registry of scalar and aggregate UDFs.
///
/// Scalar and aggregate UDFs are kept in separate maps, so the same name can
/// be registered once in each.
pub struct UdfRegistry {
    /// Scalar UDFs keyed by the string returned from `ScalarUDF::name`.
    scalar_udfs: FxHashMap<String, Arc<dyn ScalarUDF>>,
    /// Aggregate UDFs keyed by the string returned from `AggregateUDF::name`.
    aggregate_udfs: FxHashMap<String, Arc<dyn AggregateUDF>>,
}
129
130impl std::fmt::Debug for UdfRegistry {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("UdfRegistry")
133 .field("scalar_udfs", &self.scalar_udfs.keys().collect::<Vec<_>>())
134 .field(
135 "aggregate_udfs",
136 &self.aggregate_udfs.keys().collect::<Vec<_>>(),
137 )
138 .finish_non_exhaustive()
139 }
140}
141
142impl UdfRegistry {
143 pub fn new() -> Self {
144 Self {
145 scalar_udfs: FxHashMap::default(),
146 aggregate_udfs: FxHashMap::default(),
147 }
148 }
149
150 pub fn register_scalar(&mut self, udf: Arc<dyn ScalarUDF>) {
152 self.scalar_udfs.insert(udf.name().to_string(), udf);
153 }
154
155 pub fn register_aggregate(&mut self, udf: Arc<dyn AggregateUDF>) {
157 self.aggregate_udfs.insert(udf.name().to_string(), udf);
158 }
159
160 pub fn get_scalar(&self, name: &str) -> Option<&Arc<dyn ScalarUDF>> {
162 self.scalar_udfs.get(name)
163 }
164
165 pub fn get_aggregate(&self, name: &str) -> Option<&Arc<dyn AggregateUDF>> {
167 self.aggregate_udfs.get(name)
168 }
169
170 pub fn validate_call(&self, name: &str, arg_types: &[Type]) -> Result<Type, String> {
174 if let Some(udf) = self.scalar_udfs.get(name) {
175 let sig = udf.signature();
176 validate_signature(&sig, arg_types)?;
177 return Ok(sig.return_type);
178 }
179 if let Some(udf) = self.aggregate_udfs.get(name) {
180 let sig = udf.signature();
181 validate_signature(&sig, arg_types)?;
182 return Ok(sig.return_type);
183 }
184 Err(format!("unknown UDF: {name}"))
185 }
186
187 pub fn is_empty(&self) -> bool {
189 self.scalar_udfs.is_empty() && self.aggregate_udfs.is_empty()
190 }
191}
192
193impl Default for UdfRegistry {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199fn validate_signature(sig: &Signature, arg_types: &[Type]) -> Result<(), String> {
200 let required = sig.input_types.len();
201
202 if let Some(variadic) = &sig.variadic {
203 if arg_types.len() < required + variadic.min_args {
204 return Err(format!(
205 "expected at least {} arguments, got {}",
206 required + variadic.min_args,
207 arg_types.len()
208 ));
209 }
210 for (i, constraint) in sig.input_types.iter().enumerate() {
212 if !constraint.matches(&arg_types[i]) {
213 return Err(format!(
214 "argument {} type mismatch: expected {:?}, got {:?}",
215 i, constraint, arg_types[i]
216 ));
217 }
218 }
219 for (i, ty) in arg_types[required..].iter().enumerate() {
221 if !variadic.arg_type.matches(ty) {
222 return Err(format!(
223 "variadic argument {} type mismatch: expected {:?}, got {:?}",
224 i, variadic.arg_type, ty
225 ));
226 }
227 }
228 } else {
229 if arg_types.len() != required {
230 return Err(format!(
231 "expected {} arguments, got {}",
232 required,
233 arg_types.len()
234 ));
235 }
236 for (i, constraint) in sig.input_types.iter().enumerate() {
237 if !constraint.matches(&arg_types[i]) {
238 return Err(format!(
239 "argument {} type mismatch: expected {:?}, got {:?}",
240 i, constraint, arg_types[i]
241 ));
242 }
243 }
244 }
245
246 Ok(())
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar test double: doubles a numeric argument, `None` otherwise.
    struct DoubleUdf;
    impl ScalarUDF for DoubleUdf {
        fn name(&self) -> &'static str {
            "double"
        }
        fn signature(&self) -> Signature {
            Signature {
                input_types: vec![TypeConstraint::Numeric],
                return_type: Type::Float,
                variadic: None,
            }
        }
        fn evaluate(&self, args: &[Value]) -> Option<Value> {
            match &args[0] {
                Value::Int(i) => Some(Value::Float(*i as f64 * 2.0)),
                Value::Float(f) => Some(Value::Float(f * 2.0)),
                _ => None,
            }
        }
    }

    /// Accumulator test double: running floating-point sum of numeric inputs.
    struct SumAccumulator {
        total: f64,
    }

    impl Accumulator for SumAccumulator {
        fn update(&mut self, value: &Value) {
            match value {
                Value::Int(i) => self.total += *i as f64,
                Value::Float(f) => self.total += f,
                _ => {}
            }
        }
        fn merge(&mut self, other: &dyn Accumulator) {
            // `other` is type-erased, so fold in its finished value instead of
            // reading its fields.
            let val = other.finish();
            self.update(&val);
        }
        fn finish(&self) -> Value {
            Value::Float(self.total)
        }
        fn reset(&mut self) {
            self.total = 0.0;
        }
        fn clone_box(&self) -> Box<dyn Accumulator> {
            Box::new(Self { total: self.total })
        }
    }

    /// Aggregate test double backed by [`SumAccumulator`].
    struct CustomSumUdf;
    impl AggregateUDF for CustomSumUdf {
        fn name(&self) -> &'static str {
            "custom_sum"
        }
        fn signature(&self) -> Signature {
            Signature {
                input_types: vec![TypeConstraint::Numeric],
                return_type: Type::Float,
                variadic: None,
            }
        }
        fn init(&self) -> Box<dyn Accumulator> {
            Box::new(SumAccumulator { total: 0.0 })
        }
    }

    #[test]
    fn test_scalar_udf_evaluate() {
        let udf = DoubleUdf;
        assert_eq!(udf.evaluate(&[Value::Int(5)]), Some(Value::Float(10.0)));
        assert_eq!(udf.evaluate(&[Value::Float(2.5)]), Some(Value::Float(5.0)));
        assert_eq!(udf.evaluate(&[Value::Null]), None);
    }

    #[test]
    fn test_registry_lookup() {
        let mut registry = UdfRegistry::new();
        assert!(registry.is_empty());

        registry.register_scalar(Arc::new(DoubleUdf));
        registry.register_aggregate(Arc::new(CustomSumUdf));

        assert!(!registry.is_empty());
        assert!(registry.get_scalar("double").is_some());
        assert!(registry.get_aggregate("custom_sum").is_some());
        assert!(registry.get_scalar("unknown").is_none());
    }

    #[test]
    fn test_validate_call_valid() {
        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(DoubleUdf));

        assert_eq!(registry.validate_call("double", &[Type::Int]), Ok(Type::Float));
    }

    #[test]
    fn test_validate_call_wrong_type() {
        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(DoubleUdf));

        let result = registry.validate_call("double", &[Type::Str]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_call_wrong_arity() {
        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(DoubleUdf));

        assert_eq!(
            registry.validate_call("double", &[Type::Int, Type::Int]),
            Err("expected 1 arguments, got 2".to_string())
        );
    }

    #[test]
    fn test_validate_unknown_udf() {
        let registry = UdfRegistry::new();
        assert_eq!(
            registry.validate_call("nonexistent", &[]),
            Err("unknown UDF: nonexistent".to_string())
        );
    }

    #[test]
    fn test_validate_call_aggregate() {
        let mut registry = UdfRegistry::new();
        registry.register_aggregate(Arc::new(CustomSumUdf));

        assert_eq!(
            registry.validate_call("custom_sum", &[Type::Int]),
            Ok(Type::Float)
        );
        assert!(registry.validate_call("custom_sum", &[Type::Str]).is_err());
        assert!(registry.validate_call("custom_sum", &[]).is_err());
    }

    #[test]
    fn test_accumulator_lifecycle() {
        let udf = CustomSumUdf;
        let mut acc = udf.init();

        acc.update(&Value::Int(10));
        acc.update(&Value::Float(5.5));
        assert_eq!(acc.finish(), Value::Float(15.5));

        let cloned = acc.clone_box();
        assert_eq!(cloned.finish(), Value::Float(15.5));

        acc.reset();
        assert_eq!(acc.finish(), Value::Float(0.0));
    }

    #[test]
    fn test_accumulator_merge() {
        let udf = CustomSumUdf;
        let mut left = udf.init();
        let mut right = udf.init();

        left.update(&Value::Int(1));
        right.update(&Value::Float(2.5));
        right.update(&Value::Int(4));

        left.merge(&*right);
        assert_eq!(left.finish(), Value::Float(7.5));
        // Merging only reads from the source accumulator.
        assert_eq!(right.finish(), Value::Float(6.5));
    }

    #[test]
    fn test_type_constraint_matches() {
        assert!(TypeConstraint::Any.matches(&Type::Int));
        assert!(TypeConstraint::Any.matches(&Type::Str));

        assert!(TypeConstraint::Numeric.matches(&Type::Int));
        assert!(TypeConstraint::Numeric.matches(&Type::Float));
        assert!(!TypeConstraint::Numeric.matches(&Type::Str));

        assert!(TypeConstraint::Exact(Type::Int).matches(&Type::Int));
        assert!(!TypeConstraint::Exact(Type::Int).matches(&Type::Float));

        assert!(TypeConstraint::OneOf(vec![Type::Int, Type::Str]).matches(&Type::Str));
        assert!(!TypeConstraint::OneOf(vec![Type::Int, Type::Str]).matches(&Type::Bool));
    }

    #[test]
    fn test_variadic_validation() {
        let mut registry = UdfRegistry::new();

        struct ConcatUdf;
        impl ScalarUDF for ConcatUdf {
            fn name(&self) -> &'static str {
                "concat"
            }
            fn signature(&self) -> Signature {
                Signature {
                    input_types: vec![],
                    return_type: Type::Str,
                    variadic: Some(VariadicSpec {
                        min_args: 1,
                        arg_type: TypeConstraint::Exact(Type::Str),
                    }),
                }
            }
            fn evaluate(&self, _args: &[Value]) -> Option<Value> {
                None
            }
        }

        registry.register_scalar(Arc::new(ConcatUdf));

        assert!(registry.validate_call("concat", &[Type::Str]).is_ok());
        assert!(registry
            .validate_call("concat", &[Type::Str, Type::Str, Type::Str])
            .is_ok());
        assert!(registry.validate_call("concat", &[]).is_err());
        assert!(registry.validate_call("concat", &[Type::Int]).is_err());
    }

    #[test]
    fn test_variadic_with_fixed_args() {
        // One fixed `Str` parameter followed by at least one numeric argument.
        struct FormatNumsUdf;
        impl ScalarUDF for FormatNumsUdf {
            fn name(&self) -> &'static str {
                "format_nums"
            }
            fn signature(&self) -> Signature {
                Signature {
                    input_types: vec![TypeConstraint::Exact(Type::Str)],
                    return_type: Type::Str,
                    variadic: Some(VariadicSpec {
                        min_args: 1,
                        arg_type: TypeConstraint::Numeric,
                    }),
                }
            }
            fn evaluate(&self, _args: &[Value]) -> Option<Value> {
                None
            }
        }

        let mut registry = UdfRegistry::new();
        registry.register_scalar(Arc::new(FormatNumsUdf));

        assert_eq!(
            registry.validate_call("format_nums", &[Type::Str, Type::Int]),
            Ok(Type::Str)
        );
        assert!(registry
            .validate_call("format_nums", &[Type::Str, Type::Int, Type::Float])
            .is_ok());

        // The fixed parameter plus `min_args` variadic arguments are required.
        assert_eq!(
            registry.validate_call("format_nums", &[Type::Str]),
            Err("expected at least 2 arguments, got 1".to_string())
        );

        // The fixed parameter is still type-checked when a variadic tail exists.
        let err = registry
            .validate_call("format_nums", &[Type::Int, Type::Int])
            .unwrap_err();
        assert!(err.starts_with("argument 0 type mismatch"), "{err}");

        // Variadic positions are numbered from the start of the tail.
        let err = registry
            .validate_call("format_nums", &[Type::Str, Type::Int, Type::Str])
            .unwrap_err();
        assert!(err.starts_with("variadic argument 1 type mismatch"), "{err}");
    }
}