// uni_query/query/rewrite/rule.rs

//! Rewrite rule trait and supporting types.
2use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use uni_cypher::ast::Expr;
5
6/// Trait for implementing query rewrite rules
7///
8/// A rewrite rule transforms function calls into equivalent predicate expressions
9/// at compile time. Rules are registered in the global registry and applied
10/// during query compilation.
11///
12/// # Example
13///
14/// ```ignore
15/// impl RewriteRule for ValidAtRule {
16///     fn function_name(&self) -> &str {
17///         "uni.temporal.validAt"
18///     }
19///
20///     fn validate_args(&self, args: &[Expr]) -> Result<(), RewriteError> {
21///         // Check arity and argument types
22///         if args.len() != 4 {
23///             return Err(RewriteError::ArityMismatch { expected: 4, got: args.len() });
24///         }
25///         // ... more validation
26///         Ok(())
27///     }
28///
29///     fn rewrite(&self, args: Vec<Expr>, _ctx: &RewriteContext) -> Result<Expr, RewriteError> {
30///         // Transform into predicate expression
31///         Ok(rewritten_expr)
32///     }
33/// }
34/// ```
35pub trait RewriteRule: Send + Sync {
36    /// The fully-qualified function name this rule matches
37    ///
38    /// Example: "uni.temporal.validAt"
39    fn function_name(&self) -> &str;
40
41    /// Validate arguments before attempting rewrite
42    ///
43    /// Returns `Ok(())` if arguments are valid and the rule can be applied.
44    /// Returns `Err(RewriteError)` if arguments don't match expected pattern.
45    ///
46    /// Common validations:
47    /// - Check arity (number of arguments)
48    /// - Verify certain arguments are string literals (property names)
49    /// - Verify certain arguments are entity references (variables)
50    /// - Check argument types
51    fn validate_args(&self, args: &[Expr]) -> Result<(), RewriteError>;
52
53    /// Perform the actual rewrite transformation
54    ///
55    /// Takes validated arguments and context, returns the rewritten expression.
56    /// This method is only called after `validate_args` returns `Ok(())` and
57    /// `is_applicable` returns `true`.
58    ///
59    /// # Arguments
60    ///
61    /// * `args` - Function arguments (already validated)
62    /// * `ctx` - Rewrite context (scope, schema, etc.)
63    ///
64    /// # Returns
65    ///
66    /// The rewritten expression, or an error if transformation fails.
67    fn rewrite(&self, args: Vec<Expr>, ctx: &RewriteContext) -> Result<Expr, RewriteError>;
68
69    /// Check if this rule can be applied in the current context
70    ///
71    /// Override this method if the rule requires specific context conditions
72    /// (e.g., schema information, variable scope).
73    ///
74    /// Default implementation: always applicable if args validate.
75    fn is_applicable(&self, _ctx: &RewriteContext) -> bool {
76        true
77    }
78}
79
80/// Function argument arity specification
81#[derive(Debug, Clone, PartialEq, Eq)]
82pub enum Arity {
83    /// Exact number of arguments required
84    Exact(usize),
85
86    /// Range of acceptable argument counts (min, max)
87    Range(usize, usize),
88
89    /// Variable number of arguments with minimum count
90    VarArgs(usize),
91}
92
93impl Arity {
94    /// Check if the given argument count satisfies this arity requirement
95    pub fn check(&self, count: usize) -> Result<(), RewriteError> {
96        let (min, max) = match self {
97            Arity::Exact(n) => (*n, *n),
98            Arity::Range(min, max) => (*min, *max),
99            Arity::VarArgs(min) => (*min, usize::MAX),
100        };
101
102        if count >= min && count <= max {
103            return Ok(());
104        }
105
106        if min == max {
107            Err(RewriteError::ArityMismatch {
108                expected: min,
109                got: count,
110            })
111        } else {
112            Err(RewriteError::ArityOutOfRange {
113                min,
114                max,
115                got: count,
116            })
117        }
118    }
119}
120
121/// Metadata about function argument requirements
122///
123/// Use this to declaratively specify argument constraints for a rewrite rule.
124#[derive(Debug, Clone)]
125pub struct ArgConstraints {
126    /// Expected number of arguments (or range)
127    pub arity: Arity,
128
129    /// Indices of arguments that must be string literals (e.g., property names)
130    pub literal_args: Vec<usize>,
131
132    /// Index of the argument that is the entity (for property access)
133    pub entity_arg: Option<usize>,
134}
135
136impl ArgConstraints {
137    /// Validate a set of arguments against these constraints
138    pub fn validate(&self, args: &[Expr]) -> Result<(), RewriteError> {
139        // Check arity
140        self.arity.check(args.len())?;
141
142        // Check literal arguments
143        for &idx in &self.literal_args {
144            if idx >= args.len() {
145                continue; // Arity check already failed
146            }
147
148            if !matches!(args[idx], Expr::Literal(_)) {
149                return Err(RewriteError::ExpectedStringLiteral { arg_index: idx });
150            }
151        }
152
153        // Check entity argument
154        if let Some(idx) = self.entity_arg {
155            if idx >= args.len() {
156                return Ok(()); // Arity check already failed
157            }
158
159            if !matches!(args[idx], Expr::Variable(_) | Expr::Property(_, _)) {
160                return Err(RewriteError::ExpectedEntityReference { arg_index: idx });
161            }
162        }
163
164        Ok(())
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    #[test]
    fn test_arity_exact() {
        let arity = Arity::Exact(3);
        assert!(arity.check(3).is_ok());
        assert!(arity.check(2).is_err());
        assert!(arity.check(4).is_err());
    }

    #[test]
    fn test_arity_range() {
        let arity = Arity::Range(2, 4);
        assert!(arity.check(1).is_err());
        assert!(arity.check(2).is_ok());
        assert!(arity.check(3).is_ok());
        assert!(arity.check(4).is_ok());
        assert!(arity.check(5).is_err());
    }

    #[test]
    fn test_arity_varargs() {
        let arity = Arity::VarArgs(2);
        assert!(arity.check(1).is_err());
        assert!(arity.check(2).is_ok());
        assert!(arity.check(10).is_ok());
    }

    /// Pins which error variant (and payload) each arity failure produces,
    /// not just that an error occurs.
    #[test]
    fn test_arity_error_variants() {
        assert!(matches!(
            Arity::Exact(3).check(2),
            Err(RewriteError::ArityMismatch { expected: 3, got: 2 })
        ));
        // A range with equal bounds reports like an exact arity.
        assert!(matches!(
            Arity::Range(2, 2).check(3),
            Err(RewriteError::ArityMismatch { expected: 2, got: 3 })
        ));
        assert!(matches!(
            Arity::Range(2, 4).check(5),
            Err(RewriteError::ArityOutOfRange { min: 2, max: 4, got: 5 })
        ));
        assert!(matches!(
            Arity::VarArgs(2).check(1),
            Err(RewriteError::ArityOutOfRange { min: 2, max, got: 1 }) if max == usize::MAX
        ));
    }

    #[test]
    fn test_arg_constraints_validate() {
        let constraints = ArgConstraints {
            arity: Arity::Exact(3),
            literal_args: vec![1],
            entity_arg: Some(0),
        };

        // Valid arguments
        let valid_args = vec![
            Expr::Variable("e".into()),
            Expr::Literal(CypherLiteral::String("prop".into())),
            Expr::Variable("x".into()),
        ];
        assert!(constraints.validate(&valid_args).is_ok());

        // Wrong arity
        let wrong_arity = vec![Expr::Variable("e".into())];
        assert!(matches!(
            constraints.validate(&wrong_arity),
            Err(RewriteError::ArityMismatch { expected: 3, got: 1 })
        ));

        // Non-literal where a string literal is expected
        let non_literal = vec![
            Expr::Variable("e".into()),
            Expr::Variable("prop".into()), // Should be literal
            Expr::Variable("x".into()),
        ];
        assert!(matches!(
            constraints.validate(&non_literal),
            Err(RewriteError::ExpectedStringLiteral { arg_index: 1 })
        ));

        // Literal where an entity reference is expected
        let non_entity = vec![
            Expr::Literal(CypherLiteral::String("e".into())),
            Expr::Literal(CypherLiteral::String("prop".into())),
            Expr::Variable("x".into()),
        ];
        assert!(matches!(
            constraints.validate(&non_entity),
            Err(RewriteError::ExpectedEntityReference { arg_index: 0 })
        ));
    }

    /// Constraint indices past the supplied arguments belong to omitted
    /// optional arguments and must not cause a failure.
    #[test]
    fn test_arg_constraints_skip_omitted_optional_args() {
        let constraints = ArgConstraints {
            arity: Arity::Range(1, 3),
            literal_args: vec![1, 2],
            entity_arg: Some(0),
        };

        let one_arg = vec![Expr::Variable("e".into())];
        assert!(constraints.validate(&one_arg).is_ok());

        let two_args = vec![
            Expr::Variable("e".into()),
            Expr::Literal(CypherLiteral::String("p".into())),
        ];
        assert!(constraints.validate(&two_args).is_ok());
    }
}