Skip to main content

uni_db/api/
impl_query.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4use crate::api::Uni;
5use crate::api::query_builder::QueryBuilder;
6use futures::StreamExt;
7use std::collections::HashMap;
8use std::sync::Arc;
9use uni_common::{Result, UniConfig, UniError};
10use uni_query::{
11    ExecuteResult, ExplainOutput, LogicalPlan, ProfileOutput, QueryCursor, QueryResult,
12    ResultNormalizer, Row, Value as ApiValue,
13};
14
/// Normalize backend/planner error text into canonical Cypher/TCK codes.
///
/// When `raw` matches one of the known backend messages, the result is `raw`
/// prefixed with a `"<Category>: <Code> - "` tag; otherwise `raw` is returned
/// unchanged. Some messages are ambiguous on their own and are disambiguated
/// by inspecting the original `cypher` text.
///
/// This keeps behavioral semantics unchanged while making error classification
/// stable across planner backends.
fn normalize_error_message(raw: &str, cypher: &str) -> String {
    let has = |needle: &str| raw.contains(needle);

    let prefix: Option<&str> = if has("Error during planning: UDF") && has("is not registered") {
        Some("SyntaxError: UnknownFunction - ")
    } else if has("_cypher_in(): second argument must be a list") {
        Some("TypeError: InvalidArgumentType - ")
    } else if has("InvalidNumberOfArguments: Procedure") && has("got 0") {
        // A procedure call written with YIELD is tagged as an argument-passing
        // syntax error; without YIELD it is reported as a missing parameter.
        if cypher.to_uppercase().contains("YIELD") {
            Some("SyntaxError: InvalidArgumentPassingMode - ")
        } else {
            Some("ParameterMissing: MissingParameter - ")
        }
    } else if has("Function count not implemented or is aggregate")
        || has("Physical plan does not support logical expression AggregateFunction")
        || has("Expected aggregate function, got: ListComprehension")
    {
        Some("SyntaxError: InvalidAggregation - ")
    } else if has("Expected aggregate function, got: BinaryOp") {
        Some("SyntaxError: AmbiguousAggregationExpression - ")
    } else if has("Schema error: No field named \"me.age\". Valid fields are \"count(you.age)\".") {
        Some("SyntaxError: UndefinedVariable - ")
    } else if has(
        "Schema error: No field named \"me.age\". Valid fields are \"me.age + you.age\", \"count(*)\".",
    ) {
        Some("SyntaxError: AmbiguousAggregationExpression - ")
    } else if has("MERGE edge must have a type")
        || has("MERGE does not support multiple edge types")
    {
        Some("SyntaxError: NoSingleRelationshipType - ")
    } else if has("MERGE node must have a label") {
        // The same backend message covers several distinct user mistakes;
        // look at the query text to decide which one it was.
        let lowered = cypher.to_lowercase();
        if cypher.contains("$param") {
            Some("SyntaxError: InvalidParameterUse - ")
        } else if cypher.contains('*') && cypher.contains("-[:") {
            Some("SyntaxError: CreatingVarLength - ")
        } else if lowered.contains("on create set x.") || lowered.contains("on match set x.") {
            Some("SyntaxError: UndefinedVariable - ")
        } else {
            None
        }
    } else {
        None
    };

    match prefix {
        Some(tag) => format!("{}{}", tag, raw),
        None => raw.to_owned(),
    }
}
66
67/// Convert a parse error into `UniError::Parse`.
68fn into_parse_error(e: impl std::fmt::Display) -> UniError {
69    UniError::Parse {
70        message: e.to_string(),
71        position: None,
72        line: None,
73        column: None,
74        context: None,
75    }
76}
77
78/// Convert a planner/compile-time error into the appropriate `UniError` type.
79/// Errors starting with "SyntaxError:" are treated as parse/syntax errors.
80/// All other errors are query/semantic errors (CompileTime).
81fn into_query_error(e: impl std::fmt::Display, cypher: &str) -> UniError {
82    let msg = normalize_error_message(&e.to_string(), cypher);
83    // Errors containing "SyntaxError:" prefix should be treated as syntax errors
84    // This covers validation errors like VariableTypeConflict, UndefinedVariable, etc.
85    if msg.starts_with("SyntaxError:") {
86        UniError::Parse {
87            message: msg,
88            position: None,
89            line: None,
90            column: None,
91            context: Some(cypher.to_string()),
92        }
93    } else {
94        UniError::Query {
95            message: msg,
96            query: Some(cypher.to_string()),
97        }
98    }
99}
100
101/// Convert an executor/runtime error into the appropriate `UniError` type.
102/// TypeError messages from UDF execution become `UniError::Type` (Runtime phase).
103/// ConstraintVerificationFailed messages become `UniError::Constraint` (Runtime phase).
104/// All other executor errors remain `UniError::Query`.
105fn into_execution_error(e: impl std::fmt::Display, cypher: &str) -> UniError {
106    let msg = normalize_error_message(&e.to_string(), cypher);
107    if msg.contains("TypeError:") {
108        UniError::Type {
109            expected: msg,
110            actual: String::new(),
111        }
112    } else if msg.starts_with("ConstraintVerificationFailed:") {
113        UniError::Constraint { message: msg }
114    } else {
115        UniError::Query {
116            message: msg,
117            query: Some(cypher.to_string()),
118        }
119    }
120}
121
122/// Extract projection column names from a LogicalPlan, preserving query order.
123/// Returns None if the plan doesn't have projections at the top level.
124fn extract_projection_order(plan: &LogicalPlan) -> Option<Vec<String>> {
125    match plan {
126        LogicalPlan::Project { projections, .. } => Some(
127            projections
128                .iter()
129                .map(|(expr, alias)| alias.clone().unwrap_or_else(|| expr.to_string_repr()))
130                .collect(),
131        ),
132        LogicalPlan::Aggregate {
133            group_by,
134            aggregates,
135            ..
136        } => {
137            let mut names: Vec<String> = group_by.iter().map(|e| e.to_string_repr()).collect();
138            names.extend(aggregates.iter().map(|e| e.to_string_repr()));
139            Some(names)
140        }
141        LogicalPlan::Limit { input, .. }
142        | LogicalPlan::Sort { input, .. }
143        | LogicalPlan::Filter { input, .. } => extract_projection_order(input),
144        _ => None,
145    }
146}
147
148impl Uni {
149    /// Get the current L0Buffer mutation count (cumulative mutations since last flush).
150    /// Used to compute affected_rows for mutation queries that return no result rows.
151    pub(crate) async fn get_mutation_count(&self) -> usize {
152        match self.writer.as_ref() {
153            Some(w) => {
154                let writer = w.read().await;
155                writer.l0_manager.get_current().read().mutation_count
156            }
157            None => 0,
158        }
159    }
160
161    /// Explain a Cypher query plan without executing it.
162    pub async fn explain(&self, cypher: &str) -> Result<ExplainOutput> {
163        let ast = uni_cypher::parse(cypher).map_err(into_parse_error)?;
164
165        let planner = uni_query::QueryPlanner::new(self.schema.schema().clone());
166        planner
167            .explain_plan(ast)
168            .map_err(|e| into_query_error(e, cypher))
169    }
170
171    /// Profile a Cypher query execution.
172    pub async fn profile(&self, cypher: &str) -> Result<(QueryResult, ProfileOutput)> {
173        let ast = uni_cypher::parse(cypher).map_err(into_parse_error)?;
174
175        let planner = uni_query::QueryPlanner::new(self.schema.schema().clone());
176        let logical_plan = planner.plan(ast).map_err(|e| into_query_error(e, cypher))?;
177
178        let mut executor = uni_query::Executor::new(self.storage.clone());
179        executor.set_config(self.config.clone());
180        executor.set_xervo_runtime(self.xervo_runtime.clone());
181        executor.set_procedure_registry(self.procedure_registry.clone());
182        if let Some(w) = &self.writer {
183            executor.set_writer(w.clone());
184        }
185
186        let params: HashMap<String, uni_common::Value> = HashMap::new(); // TODO: Support params in profile
187
188        // Extract projection order
189        let projection_order = extract_projection_order(&logical_plan);
190
191        let (results, profile_output) = executor
192            .profile(logical_plan, &params)
193            .await
194            .map_err(|e| into_execution_error(e, cypher))?;
195
196        // Convert results to QueryResult
197        let columns = if results.is_empty() {
198            Arc::new(vec![])
199        } else if let Some(order) = projection_order {
200            Arc::new(order)
201        } else {
202            let mut cols: Vec<String> = results[0].keys().cloned().collect();
203            cols.sort();
204            Arc::new(cols)
205        };
206
207        let rows = results
208            .into_iter()
209            .map(|map| {
210                let mut values = Vec::with_capacity(columns.len());
211                for col in columns.iter() {
212                    let value = map.get(col).cloned().unwrap_or(ApiValue::Null);
213                    // Normalize to ensure proper Node/Edge/Path types
214                    let normalized =
215                        ResultNormalizer::normalize_value(value).unwrap_or(ApiValue::Null);
216                    values.push(normalized);
217                }
218                Row {
219                    columns: columns.clone(),
220                    values,
221                }
222            })
223            .collect();
224
225        Ok((
226            QueryResult {
227                columns,
228                rows,
229                warnings: Vec::new(),
230            },
231            profile_output,
232        ))
233    }
234
235    /// Execute a Cypher query
236    pub async fn query(&self, cypher: &str) -> Result<QueryResult> {
237        self.execute_internal(cypher, HashMap::new()).await
238    }
239
240    /// Execute query returning a cursor for streaming results
241    pub async fn query_cursor(&self, cypher: &str) -> Result<QueryCursor> {
242        self.execute_cursor_internal(cypher, HashMap::new()).await
243    }
244
245    /// Execute a query with parameters using a builder
246    pub fn query_with(&self, cypher: &str) -> QueryBuilder<'_> {
247        QueryBuilder::new(self, cypher)
248    }
249
250    /// Execute a mutation with parameters using a builder.
251    ///
252    /// Alias for [`query_with`](Self::query_with) that clarifies intent for
253    /// mutation queries. Use `.param()` to bind parameters, then `.execute()`
254    /// to run the mutation.
255    ///
256    /// # Examples
257    ///
258    /// ```no_run
259    /// # use uni_db::Uni;
260    /// # async fn example(db: &Uni) -> uni_db::Result<()> {
261    /// db.execute_with("CREATE (p:Person {name: $name, age: $age})")
262    ///     .param("name", "Alice")
263    ///     .param("age", 30)
264    ///     .execute()
265    ///     .await?;
266    /// # Ok(())
267    /// # }
268    /// ```
269    pub fn execute_with(&self, cypher: &str) -> QueryBuilder<'_> {
270        self.query_with(cypher)
271    }
272
273    /// Execute a modification query (CREATE, SET, DELETE, etc.)
274    /// Returns the number of affected rows/elements
275    pub async fn execute(&self, cypher: &str) -> Result<ExecuteResult> {
276        let before = self.get_mutation_count().await;
277        let result = self.execute_internal(cypher, HashMap::new()).await?;
278        let affected_rows = if result.is_empty() {
279            self.get_mutation_count().await.saturating_sub(before)
280        } else {
281            result.len()
282        };
283        Ok(ExecuteResult { affected_rows })
284    }
285
286    pub(crate) async fn execute_cursor_internal(
287        &self,
288        cypher: &str,
289        params: HashMap<String, ApiValue>,
290    ) -> Result<QueryCursor> {
291        self.execute_cursor_internal_with_config(cypher, params, self.config.clone())
292            .await
293    }
294
295    pub(crate) async fn execute_cursor_internal_with_config(
296        &self,
297        cypher: &str,
298        params: HashMap<String, ApiValue>,
299        config: UniConfig,
300    ) -> Result<QueryCursor> {
301        let ast = uni_cypher::parse(cypher).map_err(into_parse_error)?;
302
303        let planner =
304            uni_query::QueryPlanner::new(self.schema.schema().clone()).with_params(params.clone());
305        let logical_plan = planner.plan(ast).map_err(|e| into_query_error(e, cypher))?;
306
307        let mut executor = uni_query::Executor::new(self.storage.clone());
308        executor.set_config(config.clone());
309        executor.set_xervo_runtime(self.xervo_runtime.clone());
310        executor.set_procedure_registry(self.procedure_registry.clone());
311        if let Some(w) = &self.writer {
312            executor.set_writer(w.clone());
313        }
314
315        let projection_order = extract_projection_order(&logical_plan);
316        let projection_order_for_rows = projection_order.clone();
317        let cypher_for_error = cypher.to_string();
318
319        let stream = executor.execute_stream(logical_plan, self.properties.clone(), params);
320
321        let row_stream = stream.map(move |batch_res| {
322            let results = batch_res.map_err(|e| {
323                let msg = normalize_error_message(&e.to_string(), &cypher_for_error);
324                if msg.contains("TypeError:") {
325                    UniError::Type {
326                        expected: msg,
327                        actual: String::new(),
328                    }
329                } else if msg.starts_with("ConstraintVerificationFailed:") {
330                    UniError::Constraint { message: msg }
331                } else {
332                    UniError::Query {
333                        message: msg,
334                        query: Some(cypher_for_error.clone()),
335                    }
336                }
337            })?;
338
339            if results.is_empty() {
340                return Ok(vec![]);
341            }
342
343            // Determine columns for this batch (should be stable for the whole query)
344            let columns = if let Some(order) = &projection_order_for_rows {
345                Arc::new(order.clone())
346            } else {
347                let mut cols: Vec<String> = results[0].keys().cloned().collect();
348                cols.sort();
349                Arc::new(cols)
350            };
351
352            let rows = results
353                .into_iter()
354                .map(|map| {
355                    let mut values = Vec::with_capacity(columns.len());
356                    for col in columns.iter() {
357                        let value = map.get(col).cloned().unwrap_or(ApiValue::Null);
358                        values.push(value);
359                    }
360                    Row {
361                        columns: columns.clone(),
362                        values,
363                    }
364                })
365                .collect();
366
367            Ok(rows)
368        });
369
370        // We need columns ahead of time for QueryCursor if possible.
371        let columns = if let Some(order) = projection_order {
372            Arc::new(order)
373        } else {
374            Arc::new(vec![])
375        };
376
377        Ok(QueryCursor {
378            columns,
379            stream: Box::pin(row_stream),
380        })
381    }
382
383    pub(crate) async fn execute_internal(
384        &self,
385        cypher: &str,
386        params: HashMap<String, ApiValue>,
387    ) -> Result<QueryResult> {
388        self.execute_internal_with_config(cypher, params, self.config.clone())
389            .await
390    }
391
392    pub(crate) async fn execute_internal_with_config(
393        &self,
394        cypher: &str,
395        params: HashMap<String, ApiValue>,
396        config: UniConfig,
397    ) -> Result<QueryResult> {
398        // Single parse: extract time-travel clause if present
399        let ast = uni_cypher::parse(cypher).map_err(into_parse_error)?;
400        let (ast, tt_spec) = match ast {
401            uni_cypher::ast::Query::TimeTravel { query, spec } => (*query, Some(spec)),
402            other => (other, None),
403        };
404
405        if let Some(spec) = tt_spec {
406            uni_query::validate_read_only(&ast).map_err(|msg| into_query_error(msg, cypher))?;
407            // Resolve to snapshot and execute on pinned instance
408            let snapshot_id = self.resolve_time_travel(&spec).await?;
409            let pinned = self.at_snapshot(&snapshot_id).await?;
410            return pinned
411                .execute_ast_internal(ast, cypher, params, config)
412                .await;
413        }
414
415        self.execute_ast_internal(ast, cypher, params, config).await
416    }
417
418    /// Execute a pre-parsed Cypher AST through the planner and executor.
419    ///
420    /// The `cypher` parameter is the original query string, used only for
421    /// error messages.
422    pub(crate) async fn execute_ast_internal(
423        &self,
424        ast: uni_query::CypherQuery,
425        cypher: &str,
426        params: HashMap<String, ApiValue>,
427        config: UniConfig,
428    ) -> Result<QueryResult> {
429        let planner =
430            uni_query::QueryPlanner::new(self.schema.schema().clone()).with_params(params.clone());
431        let logical_plan = planner.plan(ast).map_err(|e| into_query_error(e, cypher))?;
432
433        let mut executor = uni_query::Executor::new(self.storage.clone());
434        executor.set_config(config.clone());
435        executor.set_xervo_runtime(self.xervo_runtime.clone());
436        executor.set_procedure_registry(self.procedure_registry.clone());
437        if let Some(w) = &self.writer {
438            executor.set_writer(w.clone());
439        }
440
441        let projection_order = extract_projection_order(&logical_plan);
442
443        let results = executor
444            .execute(logical_plan, &self.properties, &params)
445            .await
446            .map_err(|e| into_execution_error(e, cypher))?;
447
448        let columns = if results.is_empty() {
449            Arc::new(vec![])
450        } else if let Some(order) = projection_order {
451            Arc::new(order)
452        } else {
453            let mut cols: Vec<String> = results[0].keys().cloned().collect();
454            cols.sort();
455            Arc::new(cols)
456        };
457
458        let rows = results
459            .into_iter()
460            .map(|map| {
461                let mut values = Vec::with_capacity(columns.len());
462                for col in columns.iter() {
463                    let value = map.get(col).cloned().unwrap_or(ApiValue::Null);
464                    // Normalize to ensure proper Node/Edge/Path types
465                    let normalized =
466                        ResultNormalizer::normalize_value(value).unwrap_or(ApiValue::Null);
467                    values.push(normalized);
468                }
469                Row {
470                    columns: columns.clone(),
471                    values,
472                }
473            })
474            .collect();
475
476        Ok(QueryResult {
477            columns,
478            rows,
479            warnings: executor.take_warnings(),
480        })
481    }
482
483    /// Resolve a time-travel spec to a snapshot ID.
484    async fn resolve_time_travel(&self, spec: &uni_query::TimeTravelSpec) -> Result<String> {
485        match spec {
486            uni_query::TimeTravelSpec::Version(id) => Ok(id.clone()),
487            uni_query::TimeTravelSpec::Timestamp(ts_str) => {
488                let ts = chrono::DateTime::parse_from_rfc3339(ts_str)
489                    .map_err(|e| {
490                        into_parse_error(format!("Invalid timestamp '{}': {}", ts_str, e))
491                    })?
492                    .with_timezone(&chrono::Utc);
493                let manifest = self
494                    .storage
495                    .snapshot_manager()
496                    .find_snapshot_at_time(ts)
497                    .await
498                    .map_err(UniError::Internal)?
499                    .ok_or_else(|| UniError::Query {
500                        message: format!("No snapshot found at or before {}", ts_str),
501                        query: None,
502                    })?;
503                Ok(manifest.snapshot_id)
504            }
505        }
506    }
507}