Skip to main content

uni_query/query/executor/
procedure.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4use super::core::*;
5use anyhow::{Result, anyhow};
6use std::collections::HashMap;
7use uni_common::Value;
8use uni_cypher::ast::Expr;
9use uni_store::QueryContext;
10use uni_store::runtime::property_manager::PropertyManager;
11
12fn success_result(success: bool) -> Result<Vec<HashMap<String, Value>>> {
13    Ok(vec![HashMap::from([(
14        "success".to_string(),
15        Value::Bool(success),
16    )])])
17}
18
/// Value type for procedure parameters and outputs.
///
/// Each variant corresponds to a Cypher type name usable in a procedure
/// signature. The enum is fieldless, so it is also `Eq` and `Hash` and can be
/// used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ProcedureValueType {
    /// Cypher STRING type.
    String,
    /// Cypher INTEGER type.
    Integer,
    /// Cypher FLOAT type (integer arguments are also accepted when checking
    /// parameter compatibility).
    Float,
    /// Cypher NUMBER type (accepts both INTEGER and FLOAT).
    Number,
    /// Cypher BOOLEAN type.
    Boolean,
    /// Accepts any value type.
    Any,
}
35
36/// Single parameter declaration for a registered procedure.
37#[derive(Debug, Clone)]
38pub struct ProcedureParam {
39    /// Parameter name as declared in the procedure signature.
40    pub name: String,
41    /// Expected type for this parameter.
42    pub param_type: ProcedureValueType,
43}
44
45/// Single output column declaration for a registered procedure.
46#[derive(Debug, Clone)]
47pub struct ProcedureOutput {
48    /// Output column name as declared in the procedure signature.
49    pub name: String,
50    /// Type of the output column.
51    pub output_type: ProcedureValueType,
52}
53
54/// A procedure registered at runtime with static mock data.
55///
56/// Used by the TCK harness to define test procedures that the query
57/// engine can call via `CALL proc.name(args) YIELD columns`.
58#[derive(Debug, Clone)]
59pub struct RegisteredProcedure {
60    /// Fully qualified procedure name (e.g. `test.my.proc`).
61    pub name: String,
62    /// Declared input parameters.
63    pub params: Vec<ProcedureParam>,
64    /// Declared output columns.
65    pub outputs: Vec<ProcedureOutput>,
66    /// Mock data rows keyed by column name.
67    pub data: Vec<HashMap<String, Value>>,
68}
69
70/// Thread-safe registry of test procedures.
71///
72/// Procedures are registered before query execution (typically by TCK
73/// step definitions) and looked up by the executor at runtime.
74#[derive(Debug, Default)]
75pub struct ProcedureRegistry {
76    procedures: std::sync::RwLock<HashMap<String, RegisteredProcedure>>,
77}
78
79impl ProcedureRegistry {
80    /// Creates an empty registry.
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    /// Registers a procedure, replacing any existing one with the same name.
86    pub fn register(&self, proc_def: RegisteredProcedure) {
87        self.procedures
88            .write()
89            .expect("ProcedureRegistry lock poisoned")
90            .insert(proc_def.name.clone(), proc_def);
91    }
92
93    /// Looks up a procedure by fully qualified name.
94    pub fn get(&self, name: &str) -> Option<RegisteredProcedure> {
95        self.procedures
96            .read()
97            .expect("ProcedureRegistry lock poisoned")
98            .get(name)
99            .cloned()
100    }
101
102    /// Removes all registered procedures.
103    pub fn clear(&self) {
104        self.procedures
105            .write()
106            .expect("ProcedureRegistry lock poisoned")
107            .clear();
108    }
109}
110
111/// Filters a full result map to only the requested yield items.
112/// If `yield_items` is empty, returns the full result unchanged.
113fn filter_yield_items(
114    full_result: HashMap<String, Value>,
115    yield_items: &[String],
116) -> HashMap<String, Value> {
117    if yield_items.is_empty() {
118        return full_result;
119    }
120    yield_items
121        .iter()
122        .filter_map(|name| full_result.get(name).map(|val| (name.clone(), val.clone())))
123        .collect()
124}
125
126impl Executor {
127    /// Evaluate a procedure argument as a string, returning an error with the given description.
128    async fn eval_string_arg<'a>(
129        &'a self,
130        arg: &Expr,
131        description: &str,
132        prop_manager: &'a PropertyManager,
133        params: &'a HashMap<String, Value>,
134        ctx: Option<&'a QueryContext>,
135    ) -> Result<String> {
136        let empty_row = HashMap::new();
137        self.evaluate_expr(arg, &empty_row, prop_manager, params, ctx)
138            .await?
139            .as_str()
140            .ok_or_else(|| anyhow!("{} must be string", description))
141            .map(|s| s.to_string())
142    }
143
144    pub(crate) async fn execute_procedure<'a>(
145        &'a self,
146        name: &str,
147        args: &[Expr],
148        yield_items: &[String],
149        prop_manager: &'a PropertyManager,
150        params: &'a HashMap<String, Value>,
151        ctx: Option<&'a QueryContext>,
152    ) -> Result<Vec<HashMap<String, Value>>> {
153        match name {
154            "uni.admin.compact" => {
155                let stats = self.storage.compact().await?;
156                let full_result = HashMap::from([
157                    (
158                        "files_compacted".to_string(),
159                        Value::Int(stats.files_compacted as i64),
160                    ),
161                    (
162                        "bytes_before".to_string(),
163                        Value::Int(stats.bytes_before as i64),
164                    ),
165                    (
166                        "bytes_after".to_string(),
167                        Value::Int(stats.bytes_after as i64),
168                    ),
169                    (
170                        "duration_ms".to_string(),
171                        Value::Int(stats.duration.as_millis() as i64),
172                    ),
173                ]);
174
175                Ok(vec![filter_yield_items(full_result, yield_items)])
176            }
177            "uni.admin.compactionStatus" => {
178                let status = self
179                    .storage
180                    .compaction_status()
181                    .map_err(|e| anyhow::anyhow!("Failed to get compaction status: {}", e))?;
182                let full_result = HashMap::from([
183                    ("l1_runs".to_string(), Value::Int(status.l1_runs as i64)),
184                    (
185                        "l1_size_bytes".to_string(),
186                        Value::Int(status.l1_size_bytes as i64),
187                    ),
188                    (
189                        "in_progress".to_string(),
190                        Value::Bool(status.compaction_in_progress),
191                    ),
192                    (
193                        "pending".to_string(),
194                        Value::Int(status.compaction_pending as i64),
195                    ),
196                    (
197                        "total_compactions".to_string(),
198                        Value::Int(status.total_compactions as i64),
199                    ),
200                    (
201                        "total_bytes_compacted".to_string(),
202                        Value::Int(status.total_bytes_compacted as i64),
203                    ),
204                ]);
205
206                Ok(vec![filter_yield_items(full_result, yield_items)])
207            }
208            "uni.admin.snapshot.create" => {
209                let name = if !args.is_empty() {
210                    Some(
211                        self.eval_string_arg(&args[0], "Snapshot name", prop_manager, params, ctx)
212                            .await?,
213                    )
214                } else {
215                    None
216                };
217
218                let writer_arc = self
219                    .writer
220                    .as_ref()
221                    .ok_or_else(|| anyhow!("Database is in read-only mode"))?;
222                let mut writer = writer_arc.write().await;
223                let snapshot_id = writer.flush_to_l1(name).await?;
224
225                Ok(vec![HashMap::from([(
226                    "snapshot_id".to_string(),
227                    Value::String(snapshot_id),
228                )])])
229            }
230            "uni.admin.snapshot.list" => {
231                let sm = self.storage.snapshot_manager();
232                let ids = sm.list_snapshots().await?;
233                let mut results = Vec::new();
234                for id in ids {
235                    if let Ok(m) = sm.load_snapshot(&id).await {
236                        results.push(HashMap::from([
237                            ("snapshot_id".to_string(), Value::String(m.snapshot_id)),
238                            (
239                                "name".to_string(),
240                                m.name.map(Value::String).unwrap_or(Value::Null),
241                            ),
242                            (
243                                "created_at".to_string(),
244                                Value::String(m.created_at.to_rfc3339()),
245                            ),
246                            (
247                                "version_hwm".to_string(),
248                                Value::Int(m.version_high_water_mark as i64),
249                            ),
250                        ]));
251                    }
252                }
253                Ok(results)
254            }
255            "uni.admin.snapshot.restore" => {
256                let id = self
257                    .eval_string_arg(&args[0], "Snapshot ID", prop_manager, params, ctx)
258                    .await?;
259
260                self.storage
261                    .snapshot_manager()
262                    .set_latest_snapshot(&id)
263                    .await?;
264                Ok(vec![HashMap::from([(
265                    "status".to_string(),
266                    Value::String("Restored".to_string()),
267                )])])
268            }
269            // DDL Procedures
270            "uni.schema.createLabel" => {
271                let empty_row = HashMap::new();
272                let name = self
273                    .eval_string_arg(&args[0], "Label name", prop_manager, params, ctx)
274                    .await?;
275                let config = self
276                    .evaluate_expr(&args[1], &empty_row, prop_manager, params, ctx)
277                    .await?;
278
279                let success =
280                    super::ddl_procedures::create_label(&self.storage, &name, &config).await?;
281                success_result(success)
282            }
283            "uni.schema.createEdgeType" => {
284                let empty_row = HashMap::new();
285                let name = self
286                    .eval_string_arg(&args[0], "Edge type name", prop_manager, params, ctx)
287                    .await?;
288                let src_val = self
289                    .evaluate_expr(&args[1], &empty_row, prop_manager, params, ctx)
290                    .await?;
291                let dst_val = self
292                    .evaluate_expr(&args[2], &empty_row, prop_manager, params, ctx)
293                    .await?;
294                let config = self
295                    .evaluate_expr(&args[3], &empty_row, prop_manager, params, ctx)
296                    .await?;
297
298                // Convert src/dst to Vec<String>
299                let src_labels = src_val
300                    .as_array()
301                    .ok_or(anyhow!("Source labels must be a list"))?
302                    .iter()
303                    .map(|v| {
304                        v.as_str()
305                            .map(|s| s.to_string())
306                            .ok_or(anyhow!("Label must be string"))
307                    })
308                    .collect::<Result<Vec<_>>>()?;
309                let dst_labels = dst_val
310                    .as_array()
311                    .ok_or(anyhow!("Target labels must be a list"))?
312                    .iter()
313                    .map(|v| {
314                        v.as_str()
315                            .map(|s| s.to_string())
316                            .ok_or(anyhow!("Label must be string"))
317                    })
318                    .collect::<Result<Vec<_>>>()?;
319
320                let success = super::ddl_procedures::create_edge_type(
321                    &self.storage,
322                    &name,
323                    src_labels,
324                    dst_labels,
325                    &config,
326                )
327                .await?;
328                success_result(success)
329            }
330            "uni.schema.createIndex" => {
331                let empty_row = HashMap::new();
332                let label = self
333                    .eval_string_arg(&args[0], "Label", prop_manager, params, ctx)
334                    .await?;
335                let property = self
336                    .eval_string_arg(&args[1], "Property", prop_manager, params, ctx)
337                    .await?;
338                let config = self
339                    .evaluate_expr(&args[2], &empty_row, prop_manager, params, ctx)
340                    .await?;
341
342                let success =
343                    super::ddl_procedures::create_index(&self.storage, &label, &property, &config)
344                        .await?;
345                success_result(success)
346            }
347            "uni.schema.createConstraint" => {
348                let label = self
349                    .eval_string_arg(&args[0], "Label", prop_manager, params, ctx)
350                    .await?;
351                let c_type = self
352                    .eval_string_arg(&args[1], "Constraint type", prop_manager, params, ctx)
353                    .await?;
354                let empty_row = HashMap::new();
355                let props_val = self
356                    .evaluate_expr(&args[2], &empty_row, prop_manager, params, ctx)
357                    .await?;
358
359                let properties = props_val
360                    .as_array()
361                    .ok_or(anyhow!("Properties must be a list"))?
362                    .iter()
363                    .map(|v| {
364                        v.as_str()
365                            .map(|s| s.to_string())
366                            .ok_or(anyhow!("Property must be string"))
367                    })
368                    .collect::<Result<Vec<_>>>()?;
369
370                let success = super::ddl_procedures::create_constraint(
371                    &self.storage,
372                    &label,
373                    &c_type,
374                    properties,
375                )
376                .await?;
377                success_result(success)
378            }
379            "uni.schema.dropLabel" => {
380                let name = self
381                    .eval_string_arg(&args[0], "Label name", prop_manager, params, ctx)
382                    .await?;
383                let success = super::ddl_procedures::drop_label(&self.storage, &name).await?;
384                success_result(success)
385            }
386            "uni.schema.dropEdgeType" => {
387                let name = self
388                    .eval_string_arg(&args[0], "Edge type name", prop_manager, params, ctx)
389                    .await?;
390                let success = super::ddl_procedures::drop_edge_type(&self.storage, &name).await?;
391                success_result(success)
392            }
393            "uni.schema.dropIndex" => {
394                let name = self
395                    .eval_string_arg(&args[0], "Index name", prop_manager, params, ctx)
396                    .await?;
397                let success = super::ddl_procedures::drop_index(&self.storage, &name).await?;
398                success_result(success)
399            }
400            "uni.schema.dropConstraint" => {
401                let name = self
402                    .eval_string_arg(&args[0], "Constraint name", prop_manager, params, ctx)
403                    .await?;
404                let success = super::ddl_procedures::drop_constraint(&self.storage, &name).await?;
405                success_result(success)
406            }
407            _ => {
408                // Check external procedure registry
409                if let Some(registry) = &self.procedure_registry
410                    && let Some(proc_def) = registry.get(name)
411                {
412                    return self
413                        .execute_registered_procedure(
414                            &proc_def,
415                            args,
416                            yield_items,
417                            prop_manager,
418                            params,
419                            ctx,
420                        )
421                        .await;
422                }
423                Err(anyhow!("ProcedureNotFound: Unknown procedure '{}'", name))
424            }
425        }
426    }
427
428    /// Executes a procedure from the external registry.
429    ///
430    /// Evaluates arguments, validates count and types against the procedure
431    /// declaration, filters data rows by matching input columns, and projects
432    /// the requested output columns.
433    ///
434    /// # Errors
435    ///
436    /// Returns `InvalidNumberOfArguments` if the argument count is wrong,
437    /// or `InvalidArgumentType` if an argument has an incompatible type.
438    async fn execute_registered_procedure<'a>(
439        &'a self,
440        proc_def: &RegisteredProcedure,
441        args: &[Expr],
442        yield_items: &[String],
443        prop_manager: &'a PropertyManager,
444        params: &'a HashMap<String, Value>,
445        ctx: Option<&'a QueryContext>,
446    ) -> Result<Vec<HashMap<String, Value>>> {
447        let empty_row = HashMap::new();
448
449        // Evaluate arguments
450        let mut evaluated_args = Vec::with_capacity(args.len());
451        for arg in args {
452            evaluated_args.push(
453                self.evaluate_expr(arg, &empty_row, prop_manager, params, ctx)
454                    .await?,
455            );
456        }
457
458        // Validate argument count
459        if evaluated_args.len() != proc_def.params.len() {
460            if evaluated_args.is_empty() && !proc_def.params.is_empty() {
461                if yield_items.is_empty() {
462                    // Standalone CALL — resolve implicit arguments from query parameters
463                    let mut resolved = Vec::with_capacity(proc_def.params.len());
464                    for param in &proc_def.params {
465                        if let Some(val) = params.get(&param.name) {
466                            resolved.push(val.clone());
467                        } else {
468                            return Err(anyhow!(
469                                "MissingParameter: Procedure '{}' requires implicit argument '{}' \
470                                 but it was not provided as a query parameter",
471                                proc_def.name,
472                                param.name
473                            ));
474                        }
475                    }
476                    evaluated_args = resolved;
477                } else {
478                    // In-query CALL with YIELD cannot use implicit arguments
479                    return Err(anyhow!(
480                        "InvalidArgumentPassingMode: Procedure '{}' requires explicit argument passing in in-query CALL",
481                        proc_def.name
482                    ));
483                }
484            } else {
485                return Err(anyhow!(
486                    "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
487                    proc_def.name,
488                    proc_def.params.len(),
489                    evaluated_args.len()
490                ));
491            }
492        }
493
494        // Validate argument types
495        for (i, (arg_val, param)) in evaluated_args.iter().zip(&proc_def.params).enumerate() {
496            if !arg_val.is_null() && !check_type_compatible(arg_val, &param.param_type) {
497                return Err(anyhow!(
498                    "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
499                    i,
500                    param.name,
501                    proc_def.name
502                ));
503            }
504        }
505
506        // Filter data rows: keep rows where input columns match the provided args
507        let filtered: Vec<&HashMap<String, Value>> = proc_def
508            .data
509            .iter()
510            .filter(|row| {
511                for (param, arg_val) in proc_def.params.iter().zip(&evaluated_args) {
512                    if let Some(row_val) = row.get(&param.name)
513                        && !values_match(row_val, arg_val)
514                    {
515                        return false;
516                    }
517                }
518                true
519            })
520            .collect();
521
522        // Collect output column names
523        let output_names: Vec<&str> = proc_def.outputs.iter().map(|o| o.name.as_str()).collect();
524
525        // Project output columns, applying yield_items filtering
526        let results = filtered
527            .into_iter()
528            .map(|row| {
529                let mut result = HashMap::new();
530                if yield_items.is_empty() {
531                    // Return all output columns
532                    for name in &output_names {
533                        if let Some(val) = row.get(*name) {
534                            result.insert((*name).to_string(), val.clone());
535                        }
536                    }
537                } else {
538                    for yield_name in yield_items {
539                        if let Some(val) = row.get(yield_name.as_str()) {
540                            result.insert(yield_name.clone(), val.clone());
541                        }
542                    }
543                }
544                result
545            })
546            .collect();
547
548        Ok(results)
549    }
550}
551
552/// Checks whether a value is compatible with a procedure type.
553fn check_type_compatible(val: &Value, expected: &ProcedureValueType) -> bool {
554    match expected {
555        ProcedureValueType::Any => true,
556        ProcedureValueType::String => val.is_string(),
557        ProcedureValueType::Boolean => val.is_bool(),
558        ProcedureValueType::Integer => val.is_i64(),
559        ProcedureValueType::Float => val.is_f64() || val.is_i64(),
560        ProcedureValueType::Number => val.is_number(),
561    }
562}
563
564/// Checks whether two values match for input-column filtering.
565fn values_match(row_val: &Value, arg_val: &Value) -> bool {
566    if arg_val.is_null() || row_val.is_null() {
567        return arg_val.is_null() && row_val.is_null();
568    }
569    // Compare numbers by f64 to handle int/float cross-comparison
570    if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
571        return (a - b).abs() < f64::EPSILON;
572    }
573    row_val == arg_val
574}