uni_query/query/rewrite/rule.rs
//! Rewrite rule trait and supporting types.
2use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use uni_cypher::ast::Expr;
5
6/// Trait for implementing query rewrite rules
7///
8/// A rewrite rule transforms function calls into equivalent predicate expressions
9/// at compile time. Rules are registered in the global registry and applied
10/// during query compilation.
11///
12/// # Example
13///
14/// ```ignore
15/// impl RewriteRule for ValidAtRule {
16/// fn function_name(&self) -> &str {
17/// "uni.temporal.validAt"
18/// }
19///
20/// fn validate_args(&self, args: &[Expr]) -> Result<(), RewriteError> {
21/// // Check arity and argument types
22/// if args.len() != 4 {
23/// return Err(RewriteError::ArityMismatch { expected: 4, got: args.len() });
24/// }
25/// // ... more validation
26/// Ok(())
27/// }
28///
29/// fn rewrite(&self, args: Vec<Expr>, _ctx: &RewriteContext) -> Result<Expr, RewriteError> {
30/// // Transform into predicate expression
31/// Ok(rewritten_expr)
32/// }
33/// }
34/// ```
35pub trait RewriteRule: Send + Sync {
36 /// The fully-qualified function name this rule matches
37 ///
38 /// Example: "uni.temporal.validAt"
39 fn function_name(&self) -> &str;
40
41 /// Validate arguments before attempting rewrite
42 ///
43 /// Returns `Ok(())` if arguments are valid and the rule can be applied.
44 /// Returns `Err(RewriteError)` if arguments don't match expected pattern.
45 ///
46 /// Common validations:
47 /// - Check arity (number of arguments)
48 /// - Verify certain arguments are string literals (property names)
49 /// - Verify certain arguments are entity references (variables)
50 /// - Check argument types
51 fn validate_args(&self, args: &[Expr]) -> Result<(), RewriteError>;
52
53 /// Perform the actual rewrite transformation
54 ///
55 /// Takes validated arguments and context, returns the rewritten expression.
56 /// This method is only called after `validate_args` returns `Ok(())` and
57 /// `is_applicable` returns `true`.
58 ///
59 /// # Arguments
60 ///
61 /// * `args` - Function arguments (already validated)
62 /// * `ctx` - Rewrite context (scope, schema, etc.)
63 ///
64 /// # Returns
65 ///
66 /// The rewritten expression, or an error if transformation fails.
67 fn rewrite(&self, args: Vec<Expr>, ctx: &RewriteContext) -> Result<Expr, RewriteError>;
68
69 /// Check if this rule can be applied in the current context
70 ///
71 /// Override this method if the rule requires specific context conditions
72 /// (e.g., schema information, variable scope).
73 ///
74 /// Default implementation: always applicable if args validate.
75 fn is_applicable(&self, _ctx: &RewriteContext) -> bool {
76 true
77 }
78}
79
/// How many arguments a function accepts.
///
/// Checked against an actual argument count with [`Arity::check`]; all
/// bounds are inclusive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many arguments.
    Exact(usize),

    /// `Range(min, max)`: at least `min` and at most `max` arguments.
    Range(usize, usize),

    /// At least this many arguments, with no upper limit.
    VarArgs(usize),
}
92
93impl Arity {
94 /// Check if the given argument count satisfies this arity requirement
95 pub fn check(&self, count: usize) -> Result<(), RewriteError> {
96 let (min, max) = match self {
97 Arity::Exact(n) => (*n, *n),
98 Arity::Range(min, max) => (*min, *max),
99 Arity::VarArgs(min) => (*min, usize::MAX),
100 };
101
102 if count >= min && count <= max {
103 return Ok(());
104 }
105
106 if min == max {
107 Err(RewriteError::ArityMismatch {
108 expected: min,
109 got: count,
110 })
111 } else {
112 Err(RewriteError::ArityOutOfRange {
113 min,
114 max,
115 got: count,
116 })
117 }
118 }
119}
120
121/// Metadata about function argument requirements
122///
123/// Use this to declaratively specify argument constraints for a rewrite rule.
124#[derive(Debug, Clone)]
125pub struct ArgConstraints {
126 /// Expected number of arguments (or range)
127 pub arity: Arity,
128
129 /// Indices of arguments that must be string literals (e.g., property names)
130 pub literal_args: Vec<usize>,
131
132 /// Index of the argument that is the entity (for property access)
133 pub entity_arg: Option<usize>,
134}
135
136impl ArgConstraints {
137 /// Validate a set of arguments against these constraints
138 pub fn validate(&self, args: &[Expr]) -> Result<(), RewriteError> {
139 // Check arity
140 self.arity.check(args.len())?;
141
142 // Check literal arguments
143 for &idx in &self.literal_args {
144 if idx >= args.len() {
145 continue; // Arity check already failed
146 }
147
148 if !matches!(args[idx], Expr::Literal(_)) {
149 return Err(RewriteError::ExpectedStringLiteral { arg_index: idx });
150 }
151 }
152
153 // Check entity argument
154 if let Some(idx) = self.entity_arg {
155 if idx >= args.len() {
156 return Ok(()); // Arity check already failed
157 }
158
159 if !matches!(args[idx], Expr::Variable(_) | Expr::Property(_, _)) {
160 return Err(RewriteError::ExpectedEntityReference { arg_index: idx });
161 }
162 }
163
164 Ok(())
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Shorthand for a variable reference expression.
    fn var(name: &'static str) -> Expr {
        Expr::Variable(name.into())
    }

    /// Shorthand for a string literal expression.
    fn string(value: &'static str) -> Expr {
        Expr::Literal(CypherLiteral::String(value.into()))
    }

    #[test]
    fn test_arity_exact() {
        let arity = Arity::Exact(3);
        assert!(arity.check(3).is_ok());
        assert!(matches!(
            arity.check(2),
            Err(RewriteError::ArityMismatch { expected: 3, got: 2 })
        ));
        assert!(matches!(
            arity.check(4),
            Err(RewriteError::ArityMismatch { expected: 3, got: 4 })
        ));
    }

    #[test]
    fn test_arity_range() {
        let arity = Arity::Range(2, 4);
        assert!(matches!(
            arity.check(1),
            Err(RewriteError::ArityOutOfRange { min: 2, max: 4, got: 1 })
        ));
        assert!(arity.check(2).is_ok());
        assert!(arity.check(3).is_ok());
        assert!(arity.check(4).is_ok());
        assert!(matches!(
            arity.check(5),
            Err(RewriteError::ArityOutOfRange { min: 2, max: 4, got: 5 })
        ));
    }

    #[test]
    fn test_arity_degenerate_range_reports_mismatch() {
        // A range with equal bounds accepts exactly one count.
        let arity = Arity::Range(2, 2);
        assert!(arity.check(2).is_ok());
        assert!(matches!(
            arity.check(3),
            Err(RewriteError::ArityMismatch { expected: 2, got: 3 })
        ));
    }

    #[test]
    fn test_arity_varargs() {
        let arity = Arity::VarArgs(2);
        assert!(matches!(
            arity.check(1),
            Err(RewriteError::ArityOutOfRange { min: 2, max, got: 1 }) if max == usize::MAX
        ));
        assert!(arity.check(2).is_ok());
        assert!(arity.check(10).is_ok());
    }

    #[test]
    fn test_arg_constraints_validate() {
        let constraints = ArgConstraints {
            arity: Arity::Exact(3),
            literal_args: vec![1],
            entity_arg: Some(0),
        };

        // Valid arguments
        assert!(constraints
            .validate(&[var("e"), string("prop"), var("x")])
            .is_ok());

        // Wrong arity is reported before any per-argument check
        assert!(matches!(
            constraints.validate(&[var("e")]),
            Err(RewriteError::ArityMismatch { expected: 3, got: 1 })
        ));

        // Non-literal where a literal is expected
        assert!(matches!(
            constraints.validate(&[var("e"), var("prop"), var("x")]),
            Err(RewriteError::ExpectedStringLiteral { arg_index: 1 })
        ));

        // Non-entity where an entity reference is expected
        assert!(matches!(
            constraints.validate(&[string("e"), string("prop"), var("x")]),
            Err(RewriteError::ExpectedEntityReference { arg_index: 0 })
        ));
    }

    #[test]
    fn test_arg_constraints_skip_absent_optional_args() {
        let constraints = ArgConstraints {
            arity: Arity::Range(1, 3),
            literal_args: vec![2],
            entity_arg: Some(1),
        };

        // Optional arguments that were not supplied are not checked.
        assert!(constraints.validate(&[var("e")]).is_ok());

        // Once supplied, they are validated.
        assert!(matches!(
            constraints.validate(&[var("e"), var("n"), var("p")]),
            Err(RewriteError::ExpectedStringLiteral { arg_index: 2 })
        ));
    }
}