uni_query/query/rewrite/mod.rs
//! Query rewriting framework
//!
//! This module provides a general-purpose framework for transforming function calls
//! into equivalent predicate expressions at compile time. The framework enables:
//!
//! - Full predicate pushdown to storage
//! - Index utilization
//! - Extensible plugin architecture for adding new rewrite rules
//!
//! # Architecture
//!
//! The framework consists of:
//!
//! - **RewriteRule trait**: Interface for implementing rewrite transformations
//! - **RewriteRegistry**: Global registry of all rewrite rules
//! - **ExpressionWalker**: Traverses expression trees and applies rules
//! - **RewriteContext**: Contextual information during rewriting
//!
//! # Example Usage
//!
//! ```ignore
//! use uni_query::rewrite::{rewrite_statement, get_stats};
//!
//! // Rewrite a complete statement
//! let rewritten_stmt = rewrite_statement(stmt)?;
//!
//! // Get statistics. Note: statistics are tracked per walker, so `get_stats`
//! // currently always returns empty statistics; use `rewrite_expr_with_context`
//! // to obtain the counters collected during a specific rewrite.
//! let stats = get_stats();
//! println!("Rewrites applied: {}", stats.functions_rewritten);
//! ```
//!
//! # Adding New Rules
//!
//! See `rules/README.md` for a guide on implementing custom rewrite rules.
35pub mod context;
36pub mod error;
37pub mod registry;
38pub mod rule;
39pub mod rules;
40pub mod walker;
41
42use context::{RewriteContext, RewriteStats};
43use error::RewriteError;
44use registry::RewriteRegistry;
45use walker::ExpressionWalker;
46
47use std::sync::OnceLock;
48use uni_cypher::ast::{Expr, Statement};
49
/// Global registry of rewrite rules, initialized once on first use.
///
/// The cell starts out empty and is filled lazily by `get_or_init_registry`,
/// which stores the result of `RewriteRegistry::with_builtin_rules()` in it.
static GLOBAL_REGISTRY: OnceLock<RewriteRegistry> = OnceLock::new();
52
53/// Get the global rewrite registry, initializing it if needed
54fn get_or_init_registry() -> &'static RewriteRegistry {
55 GLOBAL_REGISTRY.get_or_init(|| {
56 tracing::info!("Initializing query rewrite framework");
57 RewriteRegistry::with_builtin_rules()
58 })
59}
60
61/// Log rewrite statistics if any functions were visited
62fn log_rewrite_stats(stats: &RewriteStats) {
63 if stats.functions_visited > 0 {
64 tracing::info!(
65 "Rewrite pass complete: {} functions visited, {} rewritten, {} skipped",
66 stats.functions_visited,
67 stats.functions_rewritten,
68 stats.functions_skipped
69 );
70
71 if !stats.errors.is_empty() {
72 tracing::debug!("Rewrite errors: {:?}", stats.errors);
73 }
74 }
75}
76
77/// Rewrite a complete query
78///
79/// This is the main entry point for applying query rewrites. It walks the
80/// entire query tree and applies registered rewrite rules to all function calls.
81///
82/// # Arguments
83///
84/// * `query` - The query to rewrite
85///
86/// # Returns
87///
88/// The rewritten query with function calls transformed into predicates.
89///
90/// # Example
91///
92/// ```ignore
93/// let query = parse_cypher("MATCH (p)-[e:EMPLOYED_BY]->(c) WHERE uni.temporal.validAt(e, 'start', 'end', datetime('2021-06-15')) RETURN c")?;
94/// let rewritten = rewrite_query(query)?;
95/// // The validAt function will be transformed into: e.start <= ... AND (e.end IS NULL OR e.end >= ...)
96/// ```
97pub fn rewrite_query(
98 query: uni_cypher::ast::Query,
99) -> Result<uni_cypher::ast::Query, RewriteError> {
100 let registry = get_or_init_registry();
101 let context = RewriteContext::default();
102
103 let mut walker = ExpressionWalker::new(registry, context);
104 let rewritten_query = walker.rewrite_query(query);
105
106 log_rewrite_stats(&walker.context().stats);
107
108 Ok(rewritten_query)
109}
110
111/// Rewrite a complete statement
112///
113/// This is a convenience function for rewriting single statements.
114///
115/// # Arguments
116///
117/// * `stmt` - The statement to rewrite
118///
119/// # Returns
120///
121/// The rewritten statement with function calls transformed into predicates.
122pub fn rewrite_statement(stmt: Statement) -> Result<Statement, RewriteError> {
123 let registry = get_or_init_registry();
124 let context = RewriteContext::default();
125
126 let mut walker = ExpressionWalker::new(registry, context);
127 let rewritten_stmt = walker.rewrite_statement(stmt);
128
129 log_rewrite_stats(&walker.context().stats);
130
131 Ok(rewritten_stmt)
132}
133
134/// Rewrite a single expression (for testing/debugging)
135///
136/// This is useful for unit testing rewrite rules or debugging transformations
137/// in isolation.
138///
139/// # Arguments
140///
141/// * `expr` - The expression to rewrite
142///
143/// # Returns
144///
145/// The rewritten expression with function calls transformed into predicates.
146pub fn rewrite_expr(expr: Expr) -> Result<Expr, RewriteError> {
147 let registry = get_or_init_registry();
148 let context = RewriteContext::default();
149
150 let mut walker = ExpressionWalker::new(registry, context);
151 Ok(walker.rewrite_expr(expr))
152}
153
154/// Rewrite an expression with custom context
155///
156/// This allows providing custom configuration and tracking statistics.
157///
158/// # Arguments
159///
160/// * `expr` - The expression to rewrite
161/// * `context` - The rewrite context with configuration
162///
163/// # Returns
164///
165/// A tuple of (rewritten expression, updated context with statistics)
166pub fn rewrite_expr_with_context(
167 expr: Expr,
168 context: RewriteContext,
169) -> Result<(Expr, RewriteContext), RewriteError> {
170 let registry = get_or_init_registry();
171
172 let mut walker = ExpressionWalker::new(registry, context);
173 let rewritten_expr = walker.rewrite_expr(expr);
174 let final_context = walker.into_context();
175
176 Ok((rewritten_expr, final_context))
177}
178
179/// Get rewrite statistics from the global registry
180///
181/// This provides observability into the rewriting process, useful for
182/// debugging and performance analysis.
183///
184/// # Returns
185///
186/// Statistics about rewrites performed (empty if no rewrites have run yet)
187pub fn get_stats() -> RewriteStats {
188 // Note: Statistics are per-walker, not global
189 // This function returns empty stats - statistics should be retrieved
190 // from the context after rewriting
191 RewriteStats::default()
192}
193
194/// Check if a function has a registered rewrite rule
195///
196/// This is useful for testing and introspection.
197///
198/// # Arguments
199///
200/// * `function_name` - The fully-qualified function name (e.g., "uni.temporal.validAt")
201///
202/// # Returns
203///
204/// `true` if a rewrite rule is registered for this function
205pub fn has_rewrite_rule(function_name: &str) -> bool {
206 let registry = get_or_init_registry();
207 registry.has_rule(function_name)
208}
209
210/// Get all registered function names
211///
212/// This is useful for testing and introspection.
213///
214/// # Returns
215///
216/// A list of all function names that have registered rewrite rules
217pub fn registered_functions() -> Vec<String> {
218 let registry = get_or_init_registry();
219 registry.registered_functions()
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Builds the integer literal used by the pass-through tests.
    fn sample_literal() -> Expr {
        Expr::Literal(CypherLiteral::Integer(42))
    }

    /// A literal contains no function calls, so rewriting must return it unchanged.
    #[test]
    fn test_rewrite_expr_basic() {
        let input = sample_literal();

        let output = rewrite_expr(input.clone()).unwrap();

        assert_eq!(output, input);
    }

    /// The built-in temporal rules are registered; unknown names are not.
    #[test]
    fn test_has_rewrite_rule() {
        let temporal_rules = [
            "uni.temporal.validAt",
            "uni.temporal.overlaps",
            "uni.temporal.isOngoing",
        ];
        for name in temporal_rules {
            assert!(has_rewrite_rule(name), "missing rewrite rule for {name}");
        }

        // Non-existent function should return false
        assert!(!has_rewrite_rule("nonexistent.function"));
    }

    /// The registry listing includes at least the temporal functions.
    #[test]
    fn test_registered_functions() {
        let names = registered_functions();

        assert!(names.len() >= 3);
        assert!(names.iter().any(|name| name == "uni.temporal.validAt"));
        assert!(names.iter().any(|name| name == "uni.temporal.overlaps"));
    }

    /// A custom context is passed through and reports no visited functions for a literal.
    #[test]
    fn test_rewrite_with_context() {
        use context::RewriteConfig;

        let input = sample_literal();
        let custom_context =
            RewriteContext::with_config(RewriteConfig::default().with_verbose_logging());

        let (output, returned_context) =
            rewrite_expr_with_context(input.clone(), custom_context).unwrap();

        assert_eq!(output, input);
        // No functions in literal
        assert_eq!(returned_context.stats.functions_visited, 0);
    }
}