1use super::ast::*;
10use std::collections::{HashMap, HashSet};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum SemanticError {
15 #[error("Undefined variable: {0}")]
16 UndefinedVariable(String),
17
18 #[error("Variable already defined: {0}")]
19 VariableAlreadyDefined(String),
20
21 #[error("Type mismatch: expected {expected}, found {found}")]
22 TypeMismatch { expected: String, found: String },
23
24 #[error("Aggregation not allowed in {0}")]
25 InvalidAggregation(String),
26
27 #[error("Cannot mix aggregated and non-aggregated expressions")]
28 MixedAggregation,
29
30 #[error("Invalid pattern: {0}")]
31 InvalidPattern(String),
32
33 #[error("Invalid hyperedge: {0}")]
34 InvalidHyperedge(String),
35
36 #[error("Property access on non-object type")]
37 InvalidPropertyAccess,
38
39 #[error(
40 "Invalid number of arguments for function {function}: expected {expected}, found {found}"
41 )]
42 InvalidArgumentCount {
43 function: String,
44 expected: usize,
45 found: usize,
46 },
47}
48
49type SemanticResult<T> = Result<T, SemanticError>;
50
51pub struct SemanticAnalyzer {
53 scope_stack: Vec<Scope>,
54 in_aggregation: bool,
55}
56
/// One level of variable bindings on the analyzer's scope stack.
#[derive(Debug, Clone)]
struct Scope {
    /// Variables declared in this scope, keyed by name, with their recorded type.
    variables: HashMap<String, ValueType>,
}
61
/// Static type assigned to expressions and variables during analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValueType {
    Integer,
    Float,
    String,
    Boolean,
    /// The null literal; treated as compatible with every other type.
    Null,
    Node,
    Relationship,
    Path,
    /// A list whose elements have the boxed type.
    List(Box<ValueType>),
    Map,
    /// Unknown type; treated as compatible with every other type.
    Any,
}

impl ValueType {
    /// Returns `true` when a value of type `self` may be used where `other`
    /// is expected.
    ///
    /// `Any` and `Null` match everything, `Integer` and `Float` match each
    /// other, lists are compared element type by element type, and every other
    /// pair must be equal.
    pub fn is_compatible_with(&self, other: &ValueType) -> bool {
        let is_wildcard = |ty: &ValueType| matches!(ty, Self::Any | Self::Null);
        if is_wildcard(self) || is_wildcard(other) {
            return true;
        }

        match (self, other) {
            (Self::Integer, Self::Float) | (Self::Float, Self::Integer) => true,
            (Self::List(lhs), Self::List(rhs)) => lhs.is_compatible_with(rhs),
            (lhs, rhs) => lhs == rhs,
        }
    }

    /// Returns `true` for `Integer`, `Float` and the unknown type `Any`.
    pub fn is_numeric(&self) -> bool {
        match self {
            Self::Integer | Self::Float | Self::Any => true,
            _ => false,
        }
    }

    /// Returns `true` for `Node`, `Relationship`, `Path` and the unknown type
    /// `Any`.
    pub fn is_graph_element(&self) -> bool {
        [
            ValueType::Node,
            ValueType::Relationship,
            ValueType::Path,
            ValueType::Any,
        ]
        .contains(self)
    }
}
103
104impl Scope {
105 fn new() -> Self {
106 Self {
107 variables: HashMap::new(),
108 }
109 }
110
111 fn define(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
112 if self.variables.contains_key(&name) {
113 Err(SemanticError::VariableAlreadyDefined(name))
114 } else {
115 self.variables.insert(name, value_type);
116 Ok(())
117 }
118 }
119
120 fn get(&self, name: &str) -> Option<&ValueType> {
121 self.variables.get(name)
122 }
123}
124
125impl SemanticAnalyzer {
126 pub fn new() -> Self {
127 Self {
128 scope_stack: vec![Scope::new()],
129 in_aggregation: false,
130 }
131 }
132
133 fn current_scope(&self) -> &Scope {
134 self.scope_stack.last().unwrap()
135 }
136
137 fn current_scope_mut(&mut self) -> &mut Scope {
138 self.scope_stack.last_mut().unwrap()
139 }
140
141 fn push_scope(&mut self) {
142 self.scope_stack.push(Scope::new());
143 }
144
145 fn pop_scope(&mut self) {
146 self.scope_stack.pop();
147 }
148
149 fn lookup_variable(&self, name: &str) -> SemanticResult<&ValueType> {
150 for scope in self.scope_stack.iter().rev() {
151 if let Some(value_type) = scope.get(name) {
152 return Ok(value_type);
153 }
154 }
155 Err(SemanticError::UndefinedVariable(name.to_string()))
156 }
157
158 fn define_variable(&mut self, name: String, value_type: ValueType) -> SemanticResult<()> {
159 self.current_scope_mut().define(name, value_type)
160 }
161
162 pub fn analyze_query(&mut self, query: &Query) -> SemanticResult<()> {
164 for statement in &query.statements {
165 self.analyze_statement(statement)?;
166 }
167 Ok(())
168 }
169
170 fn analyze_statement(&mut self, statement: &Statement) -> SemanticResult<()> {
171 match statement {
172 Statement::Match(clause) => self.analyze_match(clause),
173 Statement::Create(clause) => self.analyze_create(clause),
174 Statement::Merge(clause) => self.analyze_merge(clause),
175 Statement::Delete(clause) => self.analyze_delete(clause),
176 Statement::Set(clause) => self.analyze_set(clause),
177 Statement::Remove(clause) => self.analyze_remove(clause),
178 Statement::Return(clause) => self.analyze_return(clause),
179 Statement::With(clause) => self.analyze_with(clause),
180 }
181 }
182
183 fn analyze_remove(&mut self, clause: &RemoveClause) -> SemanticResult<()> {
184 for item in &clause.items {
185 match item {
186 RemoveItem::Property { variable, .. } => {
187 self.lookup_variable(variable)?;
189 }
190 RemoveItem::Labels { variable, .. } => {
191 self.lookup_variable(variable)?;
193 }
194 }
195 }
196 Ok(())
197 }
198
199 fn analyze_match(&mut self, clause: &MatchClause) -> SemanticResult<()> {
200 for pattern in &clause.patterns {
202 self.analyze_pattern(pattern)?;
203 }
204
205 if let Some(where_clause) = &clause.where_clause {
207 let expr_type = self.analyze_expression(&where_clause.condition)?;
208 if !expr_type.is_compatible_with(&ValueType::Boolean) {
209 return Err(SemanticError::TypeMismatch {
210 expected: "Boolean".to_string(),
211 found: format!("{:?}", expr_type),
212 });
213 }
214 }
215
216 Ok(())
217 }
218
219 fn analyze_pattern(&mut self, pattern: &Pattern) -> SemanticResult<()> {
220 match pattern {
221 Pattern::Node(node) => self.analyze_node_pattern(node),
222 Pattern::Relationship(rel) => self.analyze_relationship_pattern(rel),
223 Pattern::Path(path) => self.analyze_path_pattern(path),
224 Pattern::Hyperedge(hyperedge) => self.analyze_hyperedge_pattern(hyperedge),
225 }
226 }
227
228 fn analyze_node_pattern(&mut self, node: &NodePattern) -> SemanticResult<()> {
229 if let Some(variable) = &node.variable {
230 self.define_variable(variable.clone(), ValueType::Node)?;
231 }
232
233 if let Some(properties) = &node.properties {
234 for expr in properties.values() {
235 self.analyze_expression(expr)?;
236 }
237 }
238
239 Ok(())
240 }
241
242 fn analyze_relationship_pattern(&mut self, rel: &RelationshipPattern) -> SemanticResult<()> {
243 self.analyze_node_pattern(&rel.from)?;
244 self.analyze_pattern(&*rel.to)?;
246
247 if let Some(variable) = &rel.variable {
248 self.define_variable(variable.clone(), ValueType::Relationship)?;
249 }
250
251 if let Some(properties) = &rel.properties {
252 for expr in properties.values() {
253 self.analyze_expression(expr)?;
254 }
255 }
256
257 if let Some(range) = &rel.range {
259 if let (Some(min), Some(max)) = (range.min, range.max) {
260 if min > max {
261 return Err(SemanticError::InvalidPattern(
262 "Minimum range cannot be greater than maximum".to_string(),
263 ));
264 }
265 }
266 }
267
268 Ok(())
269 }
270
271 fn analyze_path_pattern(&mut self, path: &PathPattern) -> SemanticResult<()> {
272 self.define_variable(path.variable.clone(), ValueType::Path)?;
273 self.analyze_pattern(&path.pattern)
274 }
275
276 fn analyze_hyperedge_pattern(&mut self, hyperedge: &HyperedgePattern) -> SemanticResult<()> {
277 if hyperedge.to.len() < 2 {
279 return Err(SemanticError::InvalidHyperedge(
280 "Hyperedge must have at least 2 target nodes".to_string(),
281 ));
282 }
283
284 if hyperedge.arity != hyperedge.to.len() + 1 {
286 return Err(SemanticError::InvalidHyperedge(
287 "Hyperedge arity doesn't match number of participating nodes".to_string(),
288 ));
289 }
290
291 self.analyze_node_pattern(&hyperedge.from)?;
292
293 for target in &hyperedge.to {
294 self.analyze_node_pattern(target)?;
295 }
296
297 if let Some(variable) = &hyperedge.variable {
298 self.define_variable(variable.clone(), ValueType::Relationship)?;
299 }
300
301 if let Some(properties) = &hyperedge.properties {
302 for expr in properties.values() {
303 self.analyze_expression(expr)?;
304 }
305 }
306
307 Ok(())
308 }
309
310 fn analyze_create(&mut self, clause: &CreateClause) -> SemanticResult<()> {
311 for pattern in &clause.patterns {
312 self.analyze_pattern(pattern)?;
313 }
314 Ok(())
315 }
316
317 fn analyze_merge(&mut self, clause: &MergeClause) -> SemanticResult<()> {
318 self.analyze_pattern(&clause.pattern)?;
319
320 if let Some(on_create) = &clause.on_create {
321 self.analyze_set(on_create)?;
322 }
323
324 if let Some(on_match) = &clause.on_match {
325 self.analyze_set(on_match)?;
326 }
327
328 Ok(())
329 }
330
331 fn analyze_delete(&mut self, clause: &DeleteClause) -> SemanticResult<()> {
332 for expr in &clause.expressions {
333 let expr_type = self.analyze_expression(expr)?;
334 if !expr_type.is_graph_element() {
335 return Err(SemanticError::TypeMismatch {
336 expected: "graph element (node, relationship, path)".to_string(),
337 found: format!("{:?}", expr_type),
338 });
339 }
340 }
341 Ok(())
342 }
343
344 fn analyze_set(&mut self, clause: &SetClause) -> SemanticResult<()> {
345 for item in &clause.items {
346 match item {
347 SetItem::Property {
348 variable, value, ..
349 } => {
350 self.lookup_variable(variable)?;
351 self.analyze_expression(value)?;
352 }
353 SetItem::Variable { variable, value } => {
354 self.lookup_variable(variable)?;
355 self.analyze_expression(value)?;
356 }
357 SetItem::Labels { variable, .. } => {
358 self.lookup_variable(variable)?;
359 }
360 }
361 }
362 Ok(())
363 }
364
365 fn analyze_return(&mut self, clause: &ReturnClause) -> SemanticResult<()> {
366 self.analyze_return_items(&clause.items)?;
367
368 if let Some(order_by) = &clause.order_by {
369 for item in &order_by.items {
370 self.analyze_expression(&item.expression)?;
371 }
372 }
373
374 if let Some(skip) = &clause.skip {
375 let skip_type = self.analyze_expression(skip)?;
376 if !skip_type.is_compatible_with(&ValueType::Integer) {
377 return Err(SemanticError::TypeMismatch {
378 expected: "Integer".to_string(),
379 found: format!("{:?}", skip_type),
380 });
381 }
382 }
383
384 if let Some(limit) = &clause.limit {
385 let limit_type = self.analyze_expression(limit)?;
386 if !limit_type.is_compatible_with(&ValueType::Integer) {
387 return Err(SemanticError::TypeMismatch {
388 expected: "Integer".to_string(),
389 found: format!("{:?}", limit_type),
390 });
391 }
392 }
393
394 Ok(())
395 }
396
397 fn analyze_with(&mut self, clause: &WithClause) -> SemanticResult<()> {
398 self.analyze_return_items(&clause.items)?;
399
400 if let Some(where_clause) = &clause.where_clause {
401 let expr_type = self.analyze_expression(&where_clause.condition)?;
402 if !expr_type.is_compatible_with(&ValueType::Boolean) {
403 return Err(SemanticError::TypeMismatch {
404 expected: "Boolean".to_string(),
405 found: format!("{:?}", expr_type),
406 });
407 }
408 }
409
410 Ok(())
411 }
412
413 fn analyze_return_items(&mut self, items: &[ReturnItem]) -> SemanticResult<()> {
414 let mut has_aggregation = false;
415 let mut has_non_aggregation = false;
416
417 for item in items {
418 let item_has_agg = item.expression.has_aggregation();
419 has_aggregation |= item_has_agg;
420 has_non_aggregation |= !item_has_agg && !item.expression.is_constant();
421 }
422
423 if has_aggregation && has_non_aggregation {
424 return Err(SemanticError::MixedAggregation);
425 }
426
427 for item in items {
428 self.analyze_expression(&item.expression)?;
429 }
430
431 Ok(())
432 }
433
434 fn analyze_expression(&mut self, expr: &Expression) -> SemanticResult<ValueType> {
435 match expr {
436 Expression::Integer(_) => Ok(ValueType::Integer),
437 Expression::Float(_) => Ok(ValueType::Float),
438 Expression::String(_) => Ok(ValueType::String),
439 Expression::Boolean(_) => Ok(ValueType::Boolean),
440 Expression::Null => Ok(ValueType::Null),
441
442 Expression::Variable(name) => {
443 self.lookup_variable(name)?;
444 Ok(ValueType::Any)
445 }
446
447 Expression::Property { object, .. } => {
448 let obj_type = self.analyze_expression(object)?;
449 if !obj_type.is_graph_element()
450 && obj_type != ValueType::Map
451 && obj_type != ValueType::Any
452 {
453 return Err(SemanticError::InvalidPropertyAccess);
454 }
455 Ok(ValueType::Any)
456 }
457
458 Expression::List(items) => {
459 if items.is_empty() {
460 Ok(ValueType::List(Box::new(ValueType::Any)))
461 } else {
462 let first_type = self.analyze_expression(&items[0])?;
463 for item in items.iter().skip(1) {
464 let item_type = self.analyze_expression(item)?;
465 if !item_type.is_compatible_with(&first_type) {
466 return Ok(ValueType::List(Box::new(ValueType::Any)));
467 }
468 }
469 Ok(ValueType::List(Box::new(first_type)))
470 }
471 }
472
473 Expression::Map(map) => {
474 for expr in map.values() {
475 self.analyze_expression(expr)?;
476 }
477 Ok(ValueType::Map)
478 }
479
480 Expression::BinaryOp { left, op, right } => {
481 let left_type = self.analyze_expression(left)?;
482 let right_type = self.analyze_expression(right)?;
483
484 match op {
485 BinaryOperator::Add
486 | BinaryOperator::Subtract
487 | BinaryOperator::Multiply
488 | BinaryOperator::Divide
489 | BinaryOperator::Modulo
490 | BinaryOperator::Power => {
491 if !left_type.is_numeric() || !right_type.is_numeric() {
492 return Err(SemanticError::TypeMismatch {
493 expected: "numeric".to_string(),
494 found: format!("{:?} and {:?}", left_type, right_type),
495 });
496 }
497 if left_type == ValueType::Float || right_type == ValueType::Float {
498 Ok(ValueType::Float)
499 } else {
500 Ok(ValueType::Integer)
501 }
502 }
503 BinaryOperator::Equal
504 | BinaryOperator::NotEqual
505 | BinaryOperator::LessThan
506 | BinaryOperator::LessThanOrEqual
507 | BinaryOperator::GreaterThan
508 | BinaryOperator::GreaterThanOrEqual => Ok(ValueType::Boolean),
509 BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
510 Ok(ValueType::Boolean)
511 }
512 _ => Ok(ValueType::Any),
513 }
514 }
515
516 Expression::UnaryOp { op, operand } => {
517 let operand_type = self.analyze_expression(operand)?;
518 match op {
519 UnaryOperator::Not | UnaryOperator::IsNull | UnaryOperator::IsNotNull => {
520 Ok(ValueType::Boolean)
521 }
522 UnaryOperator::Minus | UnaryOperator::Plus => Ok(operand_type),
523 }
524 }
525
526 Expression::FunctionCall { args, .. } => {
527 for arg in args {
528 self.analyze_expression(arg)?;
529 }
530 Ok(ValueType::Any)
531 }
532
533 Expression::Aggregation { expression, .. } => {
534 let old_in_agg = self.in_aggregation;
535 self.in_aggregation = true;
536 let result = self.analyze_expression(expression);
537 self.in_aggregation = old_in_agg;
538 result?;
539 Ok(ValueType::Any)
540 }
541
542 Expression::PatternPredicate(pattern) => {
543 self.analyze_pattern(pattern)?;
544 Ok(ValueType::Boolean)
545 }
546
547 Expression::Case {
548 expression,
549 alternatives,
550 default,
551 } => {
552 if let Some(expr) = expression {
553 self.analyze_expression(expr)?;
554 }
555
556 for (condition, result) in alternatives {
557 self.analyze_expression(condition)?;
558 self.analyze_expression(result)?;
559 }
560
561 if let Some(default_expr) = default {
562 self.analyze_expression(default_expr)?;
563 }
564
565 Ok(ValueType::Any)
566 }
567 }
568 }
569}
570
571impl Default for SemanticAnalyzer {
572 fn default() -> Self {
573 Self::new()
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cypher::parser::parse_cypher;

    /// Parses `source` (panicking on a parse error) and runs a fresh analyzer
    /// over the resulting query.
    fn analyze(source: &str) -> SemanticResult<()> {
        let parsed = parse_cypher(source).unwrap();
        SemanticAnalyzer::new().analyze_query(&parsed)
    }

    /// A plain MATCH/RETURN over a bound variable passes analysis.
    #[test]
    fn test_analyze_simple_match() {
        let outcome = analyze("MATCH (n:Person) RETURN n");
        assert!(outcome.is_ok());
    }

    /// Returning a variable that was never bound is reported as undefined.
    #[test]
    fn test_undefined_variable() {
        let outcome = analyze("MATCH (n:Person) RETURN m");
        assert!(matches!(outcome, Err(SemanticError::UndefinedVariable(_))));
    }

    /// Mixing a plain projection with an aggregate is rejected.
    #[test]
    fn test_mixed_aggregation() {
        let outcome = analyze("MATCH (n:Person) RETURN n.name, COUNT(n)");
        assert!(matches!(outcome, Err(SemanticError::MixedAggregation)));
    }

    /// A hyperedge with two targets passes analysis once the parser supports it.
    #[test]
    #[ignore = "Hyperedge syntax not yet implemented in parser"]
    fn test_hyperedge_validation() {
        let outcome = analyze("MATCH (a)-[r:REL]->(b, c) RETURN a, r, b, c");
        assert!(outcome.is_ok());
    }
}
616}