ruvector_graph/cypher/
semantic.rs

1//! Semantic analysis and type checking for Cypher queries
2//!
3//! Validates the semantic correctness of parsed Cypher queries including:
4//! - Variable scope checking
5//! - Type compatibility validation
6//! - Aggregation context verification
7//! - Pattern validity
8
9use super::ast::*;
10use std::collections::{HashMap, HashSet};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum SemanticError {
15    #[error("Undefined variable: {0}")]
16    UndefinedVariable(String),
17
18    #[error("Variable already defined: {0}")]
19    VariableAlreadyDefined(String),
20
21    #[error("Type mismatch: expected {expected}, found {found}")]
22    TypeMismatch { expected: String, found: String },
23
24    #[error("Aggregation not allowed in {0}")]
25    InvalidAggregation(String),
26
27    #[error("Cannot mix aggregated and non-aggregated expressions")]
28    MixedAggregation,
29
30    #[error("Invalid pattern: {0}")]
31    InvalidPattern(String),
32
33    #[error("Invalid hyperedge: {0}")]
34    InvalidHyperedge(String),
35
36    #[error("Property access on non-object type")]
37    InvalidPropertyAccess,
38
39    #[error(
40        "Invalid number of arguments for function {function}: expected {expected}, found {found}"
41    )]
42    InvalidArgumentCount {
43        function: String,
44        expected: usize,
45        found: usize,
46    },
47}
48
49type SemanticResult<T> = Result<T, SemanticError>;
50
51/// Semantic analyzer for Cypher queries
52pub struct SemanticAnalyzer {
53    scope_stack: Vec<Scope>,
54    in_aggregation: bool,
55}
56
57#[derive(Debug, Clone)]
58struct Scope {
59    variables: HashMap<String, ValueType>,
60}
61
/// Static types the analyzer can assign to Cypher values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValueType {
    Integer,
    Float,
    String,
    Boolean,
    Null,
    Node,
    Relationship,
    Path,
    /// A list together with the type of its elements.
    List(Box<ValueType>),
    Map,
    /// Unknown type; accepted wherever any specific type is expected.
    Any,
}

impl ValueType {
    /// Returns `true` when a value of this type may be used where `other` is
    /// expected (the relation is symmetric).
    ///
    /// `Any` and `Null` are compatible with everything, `Integer` and `Float`
    /// are interchangeable, lists are compared element type by element type,
    /// and all remaining types must be equal.
    pub fn is_compatible_with(&self, other: &ValueType) -> bool {
        let is_wildcard = |ty: &ValueType| matches!(ty, ValueType::Any | ValueType::Null);
        if is_wildcard(self) || is_wildcard(other) {
            return true;
        }

        match (self, other) {
            (ValueType::Integer, ValueType::Float) | (ValueType::Float, ValueType::Integer) => true,
            (ValueType::List(lhs), ValueType::List(rhs)) => lhs.is_compatible_with(rhs),
            (lhs, rhs) => lhs == rhs,
        }
    }

    /// Returns `true` for `Integer`, `Float` and the unknown type `Any`.
    pub fn is_numeric(&self) -> bool {
        match self {
            ValueType::Integer | ValueType::Float | ValueType::Any => true,
            _ => false,
        }
    }

    /// Returns `true` for `Node`, `Relationship`, `Path` and the unknown type
    /// `Any`.
    pub fn is_graph_element(&self) -> bool {
        match self {
            ValueType::Node | ValueType::Relationship | ValueType::Path | ValueType::Any => true,
            _ => false,
        }
    }
}
103
104impl Scope {
105    fn new() -> Self {
106        Self {
107            variables: HashMap::new(),
108        }
109    }
110
111    fn define(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
112        if self.variables.contains_key(&name) {
113            Err(SemanticError::VariableAlreadyDefined(name))
114        } else {
115            self.variables.insert(name, value_type);
116            Ok(())
117        }
118    }
119
120    fn get(&self, name: &str) -> Option<&ValueType> {
121        self.variables.get(name)
122    }
123}
124
125impl SemanticAnalyzer {
126    pub fn new() -> Self {
127        Self {
128            scope_stack: vec![Scope::new()],
129            in_aggregation: false,
130        }
131    }
132
133    fn current_scope(&self) -> &Scope {
134        self.scope_stack.last().unwrap()
135    }
136
137    fn current_scope_mut(&mut self) -> &mut Scope {
138        self.scope_stack.last_mut().unwrap()
139    }
140
141    fn push_scope(&mut self) {
142        self.scope_stack.push(Scope::new());
143    }
144
145    fn pop_scope(&mut self) {
146        self.scope_stack.pop();
147    }
148
149    fn lookup_variable(&self, name: &str) -> SemanticResult<&ValueType> {
150        for scope in self.scope_stack.iter().rev() {
151            if let Some(value_type) = scope.get(name) {
152                return Ok(value_type);
153            }
154        }
155        Err(SemanticError::UndefinedVariable(name.to_string()))
156    }
157
158    fn define_variable(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
159        self.current_scope_mut().define(name, value_type)
160    }
161
162    /// Analyze a complete query
163    pub fn analyze_query(&mut self, query: &Query) -> SemanticResult<()> {
164        for statement in &query.statements {
165            self.analyze_statement(statement)?;
166        }
167        Ok(())
168    }
169
170    fn analyze_statement(&mut self, statement: &Statement) -> SemanticResult<()> {
171        match statement {
172            Statement::Match(clause) => self.analyze_match(clause),
173            Statement::Create(clause) => self.analyze_create(clause),
174            Statement::Merge(clause) => self.analyze_merge(clause),
175            Statement::Delete(clause) => self.analyze_delete(clause),
176            Statement::Set(clause) => self.analyze_set(clause),
177            Statement::Remove(clause) => self.analyze_remove(clause),
178            Statement::Return(clause) => self.analyze_return(clause),
179            Statement::With(clause) => self.analyze_with(clause),
180        }
181    }
182
183    fn analyze_remove(&mut self, clause: &RemoveClause) -> SemanticResult<()> {
184        for item in &clause.items {
185            match item {
186                RemoveItem::Property { variable, .. } => {
187                    // Verify variable is defined
188                    self.lookup_variable(variable)?;
189                }
190                RemoveItem::Labels { variable, .. } => {
191                    // Verify variable is defined
192                    self.lookup_variable(variable)?;
193                }
194            }
195        }
196        Ok(())
197    }
198
199    fn analyze_match(&mut self, clause: &MatchClause) -> SemanticResult<()> {
200        // Analyze patterns and define variables
201        for pattern in &clause.patterns {
202            self.analyze_pattern(pattern)?;
203        }
204
205        // Analyze WHERE clause
206        if let Some(where_clause) = &clause.where_clause {
207            let expr_type = self.analyze_expression(&where_clause.condition)?;
208            if !expr_type.is_compatible_with(&ValueType::Boolean) {
209                return Err(SemanticError::TypeMismatch {
210                    expected: "Boolean".to_string(),
211                    found: format!("{:?}", expr_type),
212                });
213            }
214        }
215
216        Ok(())
217    }
218
219    fn analyze_pattern(&mut self, pattern: &Pattern) -> SemanticResult<()> {
220        match pattern {
221            Pattern::Node(node) => self.analyze_node_pattern(node),
222            Pattern::Relationship(rel) => self.analyze_relationship_pattern(rel),
223            Pattern::Path(path) => self.analyze_path_pattern(path),
224            Pattern::Hyperedge(hyperedge) => self.analyze_hyperedge_pattern(hyperedge),
225        }
226    }
227
228    fn analyze_node_pattern(&mut self, node: &NodePattern) -> SemanticResult<()> {
229        if let Some(variable) = &node.variable {
230            self.define_variable(variable.clone(), ValueType::Node)?;
231        }
232
233        if let Some(properties) = &node.properties {
234            for expr in properties.values() {
235                self.analyze_expression(expr)?;
236            }
237        }
238
239        Ok(())
240    }
241
242    fn analyze_relationship_pattern(&mut self, rel: &RelationshipPattern) -> SemanticResult<()> {
243        self.analyze_node_pattern(&rel.from)?;
244        // rel.to is now a Pattern (can be NodePattern or chained RelationshipPattern)
245        self.analyze_pattern(&*rel.to)?;
246
247        if let Some(variable) = &rel.variable {
248            self.define_variable(variable.clone(), ValueType::Relationship)?;
249        }
250
251        if let Some(properties) = &rel.properties {
252            for expr in properties.values() {
253                self.analyze_expression(expr)?;
254            }
255        }
256
257        // Validate range if present
258        if let Some(range) = &rel.range {
259            if let (Some(min), Some(max)) = (range.min, range.max) {
260                if min > max {
261                    return Err(SemanticError::InvalidPattern(
262                        "Minimum range cannot be greater than maximum".to_string(),
263                    ));
264                }
265            }
266        }
267
268        Ok(())
269    }
270
271    fn analyze_path_pattern(&mut self, path: &PathPattern) -> SemanticResult<()> {
272        self.define_variable(path.variable.clone(), ValueType::Path)?;
273        self.analyze_pattern(&path.pattern)
274    }
275
276    fn analyze_hyperedge_pattern(&mut self, hyperedge: &HyperedgePattern) -> SemanticResult<()> {
277        // Validate hyperedge has at least 2 target nodes
278        if hyperedge.to.len() < 2 {
279            return Err(SemanticError::InvalidHyperedge(
280                "Hyperedge must have at least 2 target nodes".to_string(),
281            ));
282        }
283
284        // Validate arity matches
285        if hyperedge.arity != hyperedge.to.len() + 1 {
286            return Err(SemanticError::InvalidHyperedge(
287                "Hyperedge arity doesn't match number of participating nodes".to_string(),
288            ));
289        }
290
291        self.analyze_node_pattern(&hyperedge.from)?;
292
293        for target in &hyperedge.to {
294            self.analyze_node_pattern(target)?;
295        }
296
297        if let Some(variable) = &hyperedge.variable {
298            self.define_variable(variable.clone(), ValueType::Relationship)?;
299        }
300
301        if let Some(properties) = &hyperedge.properties {
302            for expr in properties.values() {
303                self.analyze_expression(expr)?;
304            }
305        }
306
307        Ok(())
308    }
309
310    fn analyze_create(&mut self, clause: &CreateClause) -> SemanticResult<()> {
311        for pattern in &clause.patterns {
312            self.analyze_pattern(pattern)?;
313        }
314        Ok(())
315    }
316
317    fn analyze_merge(&mut self, clause: &MergeClause) -> SemanticResult<()> {
318        self.analyze_pattern(&clause.pattern)?;
319
320        if let Some(on_create) = &clause.on_create {
321            self.analyze_set(on_create)?;
322        }
323
324        if let Some(on_match) = &clause.on_match {
325            self.analyze_set(on_match)?;
326        }
327
328        Ok(())
329    }
330
331    fn analyze_delete(&mut self, clause: &DeleteClause) -> SemanticResult<()> {
332        for expr in &clause.expressions {
333            let expr_type = self.analyze_expression(expr)?;
334            if !expr_type.is_graph_element() {
335                return Err(SemanticError::TypeMismatch {
336                    expected: "graph element (node, relationship, path)".to_string(),
337                    found: format!("{:?}", expr_type),
338                });
339            }
340        }
341        Ok(())
342    }
343
344    fn analyze_set(&mut self, clause: &SetClause) -> SemanticResult<()> {
345        for item in &clause.items {
346            match item {
347                SetItem::Property {
348                    variable, value, ..
349                } => {
350                    self.lookup_variable(variable)?;
351                    self.analyze_expression(value)?;
352                }
353                SetItem::Variable { variable, value } => {
354                    self.lookup_variable(variable)?;
355                    self.analyze_expression(value)?;
356                }
357                SetItem::Labels { variable, .. } => {
358                    self.lookup_variable(variable)?;
359                }
360            }
361        }
362        Ok(())
363    }
364
365    fn analyze_return(&mut self, clause: &ReturnClause) -> SemanticResult<()> {
366        self.analyze_return_items(&clause.items)?;
367
368        if let Some(order_by) = &clause.order_by {
369            for item in &order_by.items {
370                self.analyze_expression(&item.expression)?;
371            }
372        }
373
374        if let Some(skip) = &clause.skip {
375            let skip_type = self.analyze_expression(skip)?;
376            if !skip_type.is_compatible_with(&ValueType::Integer) {
377                return Err(SemanticError::TypeMismatch {
378                    expected: "Integer".to_string(),
379                    found: format!("{:?}", skip_type),
380                });
381            }
382        }
383
384        if let Some(limit) = &clause.limit {
385            let limit_type = self.analyze_expression(limit)?;
386            if !limit_type.is_compatible_with(&ValueType::Integer) {
387                return Err(SemanticError::TypeMismatch {
388                    expected: "Integer".to_string(),
389                    found: format!("{:?}", limit_type),
390                });
391            }
392        }
393
394        Ok(())
395    }
396
397    fn analyze_with(&mut self, clause: &WithClause) -> SemanticResult<()> {
398        self.analyze_return_items(&clause.items)?;
399
400        if let Some(where_clause) = &clause.where_clause {
401            let expr_type = self.analyze_expression(&where_clause.condition)?;
402            if !expr_type.is_compatible_with(&ValueType::Boolean) {
403                return Err(SemanticError::TypeMismatch {
404                    expected: "Boolean".to_string(),
405                    found: format!("{:?}", expr_type),
406                });
407            }
408        }
409
410        Ok(())
411    }
412
413    fn analyze_return_items(&mut self, items: &[ReturnItem]) -> SemanticResult<()> {
414        let mut has_aggregation = false;
415        let mut has_non_aggregation = false;
416
417        for item in items {
418            let item_has_agg = item.expression.has_aggregation();
419            has_aggregation |= item_has_agg;
420            has_non_aggregation |= !item_has_agg && !item.expression.is_constant();
421        }
422
423        if has_aggregation && has_non_aggregation {
424            return Err(SemanticError::MixedAggregation);
425        }
426
427        for item in items {
428            self.analyze_expression(&item.expression)?;
429        }
430
431        Ok(())
432    }
433
434    fn analyze_expression(&mut self, expr: &Expression) -> SemanticResult<ValueType> {
435        match expr {
436            Expression::Integer(_) => Ok(ValueType::Integer),
437            Expression::Float(_) => Ok(ValueType::Float),
438            Expression::String(_) => Ok(ValueType::String),
439            Expression::Boolean(_) => Ok(ValueType::Boolean),
440            Expression::Null => Ok(ValueType::Null),
441
442            Expression::Variable(name) => {
443                self.lookup_variable(name)?;
444                Ok(ValueType::Any)
445            }
446
447            Expression::Property { object, .. } => {
448                let obj_type = self.analyze_expression(object)?;
449                if !obj_type.is_graph_element()
450                    && obj_type != ValueType::Map
451                    && obj_type != ValueType::Any
452                {
453                    return Err(SemanticError::InvalidPropertyAccess);
454                }
455                Ok(ValueType::Any)
456            }
457
458            Expression::List(items) => {
459                if items.is_empty() {
460                    Ok(ValueType::List(Box::new(ValueType::Any)))
461                } else {
462                    let first_type = self.analyze_expression(&items[0])?;
463                    for item in items.iter().skip(1) {
464                        let item_type = self.analyze_expression(item)?;
465                        if !item_type.is_compatible_with(&first_type) {
466                            return Ok(ValueType::List(Box::new(ValueType::Any)));
467                        }
468                    }
469                    Ok(ValueType::List(Box::new(first_type)))
470                }
471            }
472
473            Expression::Map(map) => {
474                for expr in map.values() {
475                    self.analyze_expression(expr)?;
476                }
477                Ok(ValueType::Map)
478            }
479
480            Expression::BinaryOp { left, op, right } => {
481                let left_type = self.analyze_expression(left)?;
482                let right_type = self.analyze_expression(right)?;
483
484                match op {
485                    BinaryOperator::Add
486                    | BinaryOperator::Subtract
487                    | BinaryOperator::Multiply
488                    | BinaryOperator::Divide
489                    | BinaryOperator::Modulo
490                    | BinaryOperator::Power => {
491                        if !left_type.is_numeric() || !right_type.is_numeric() {
492                            return Err(SemanticError::TypeMismatch {
493                                expected: "numeric".to_string(),
494                                found: format!("{:?} and {:?}", left_type, right_type),
495                            });
496                        }
497                        if left_type == ValueType::Float || right_type == ValueType::Float {
498                            Ok(ValueType::Float)
499                        } else {
500                            Ok(ValueType::Integer)
501                        }
502                    }
503                    BinaryOperator::Equal
504                    | BinaryOperator::NotEqual
505                    | BinaryOperator::LessThan
506                    | BinaryOperator::LessThanOrEqual
507                    | BinaryOperator::GreaterThan
508                    | BinaryOperator::GreaterThanOrEqual => Ok(ValueType::Boolean),
509                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
510                        Ok(ValueType::Boolean)
511                    }
512                    _ => Ok(ValueType::Any),
513                }
514            }
515
516            Expression::UnaryOp { op, operand } => {
517                let operand_type = self.analyze_expression(operand)?;
518                match op {
519                    UnaryOperator::Not | UnaryOperator::IsNull | UnaryOperator::IsNotNull => {
520                        Ok(ValueType::Boolean)
521                    }
522                    UnaryOperator::Minus | UnaryOperator::Plus => Ok(operand_type),
523                }
524            }
525
526            Expression::FunctionCall { args, .. } => {
527                for arg in args {
528                    self.analyze_expression(arg)?;
529                }
530                Ok(ValueType::Any)
531            }
532
533            Expression::Aggregation { expression, .. } => {
534                let old_in_agg = self.in_aggregation;
535                self.in_aggregation = true;
536                let result = self.analyze_expression(expression);
537                self.in_aggregation = old_in_agg;
538                result?;
539                Ok(ValueType::Any)
540            }
541
542            Expression::PatternPredicate(pattern) => {
543                self.analyze_pattern(pattern)?;
544                Ok(ValueType::Boolean)
545            }
546
547            Expression::Case {
548                expression,
549                alternatives,
550                default,
551            } => {
552                if let Some(expr) = expression {
553                    self.analyze_expression(expr)?;
554                }
555
556                for (condition, result) in alternatives {
557                    self.analyze_expression(condition)?;
558                    self.analyze_expression(result)?;
559                }
560
561                if let Some(default_expr) = default {
562                    self.analyze_expression(default_expr)?;
563                }
564
565                Ok(ValueType::Any)
566            }
567        }
568    }
569}
570
571impl Default for SemanticAnalyzer {
572    fn default() -> Self {
573        Self::new()
574    }
575}
576
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// Parses `source` (panicking on a parse error) and runs a fresh analyzer
    /// over the resulting query.
    fn analyze(source: &str) -> SemanticResult<()> {
        let query = parse_cypher(source).unwrap();
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.analyze_query(&query)
    }

    #[test]
    fn test_analyze_simple_match() {
        let outcome = analyze("MATCH (n:Person) RETURN n");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_undefined_variable() {
        let outcome = analyze("MATCH (n:Person) RETURN m");
        assert!(matches!(outcome, Err(SemanticError::UndefinedVariable(_))));
    }

    #[test]
    fn test_mixed_aggregation() {
        let outcome = analyze("MATCH (n:Person) RETURN n.name, COUNT(n)");
        assert!(matches!(outcome, Err(SemanticError::MixedAggregation)));
    }

    #[test]
    #[ignore = "Hyperedge syntax not yet implemented in parser"]
    fn test_hyperedge_validation() {
        let outcome = analyze("MATCH (a)-[r:REL]->(b, c) RETURN a, r, b, c");
        assert!(outcome.is_ok());
    }
}