Skip to main content

uni_query_functions/rewrite/
registry.rs

//! Registry of rewrite rules, keyed by the name of the function each rule rewrites.
2use crate::rewrite::rule::RewriteRule;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6/// Registry for all rewrite rules
7///
8/// The registry maintains a map from function names to their rewrite rules.
9/// It is initialized once at startup with all built-in rules and can be
10/// extended with custom rules.
11pub struct RewriteRegistry {
12    /// Map from function name to rewrite rule
13    rules: HashMap<String, Arc<dyn RewriteRule>>,
14}
15
16impl RewriteRegistry {
17    /// Create a new empty registry
18    pub fn new() -> Self {
19        Self {
20            rules: HashMap::new(),
21        }
22    }
23
24    /// Create a registry with all built-in rules registered
25    pub fn with_builtin_rules() -> Self {
26        let mut registry = Self::new();
27        crate::rewrite::rules::register_builtin_rules(&mut registry);
28        registry
29    }
30
31    /// Register a new rewrite rule
32    ///
33    /// If a rule with the same function name already exists, it will be replaced.
34    pub fn register(&mut self, rule: Arc<dyn RewriteRule>) {
35        let function_name = rule.function_name().to_string();
36        tracing::debug!("Registering rewrite rule: {}", function_name);
37        self.rules.insert(function_name, rule);
38    }
39
40    /// Get the rewrite rule for a function name
41    pub fn get_rule(&self, function_name: &str) -> Option<&dyn RewriteRule> {
42        self.rules.get(function_name).map(|r| r.as_ref())
43    }
44
45    /// Check if a function has a rewrite rule
46    pub fn has_rule(&self, function_name: &str) -> bool {
47        self.rules.contains_key(function_name)
48    }
49
50    /// Get all registered function names
51    pub fn registered_functions(&self) -> Vec<String> {
52        self.rules.keys().cloned().collect()
53    }
54
55    /// Get the number of registered rules
56    pub fn len(&self) -> usize {
57        self.rules.len()
58    }
59
60    /// Check if the registry is empty
61    pub fn is_empty(&self) -> bool {
62        self.rules.is_empty()
63    }
64}
65
66impl Default for RewriteRegistry {
67    fn default() -> Self {
68        Self::with_builtin_rules()
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rewrite::context::RewriteContext;
    use crate::rewrite::error::RewriteError;
    use uni_cypher::ast::{CypherLiteral, Expr};

    /// Minimal rule used to exercise the registry: it accepts any arguments and
    /// returns its first argument unchanged (or a `Null` literal when given none).
    struct DummyRule {
        name: String,
    }

    impl DummyRule {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }
    }

    impl RewriteRule for DummyRule {
        fn function_name(&self) -> &str {
            &self.name
        }

        fn validate_args(&self, _args: &[Expr]) -> Result<(), RewriteError> {
            Ok(())
        }

        fn rewrite(&self, args: Vec<Expr>, _ctx: &RewriteContext) -> Result<Expr, RewriteError> {
            // Just return the first argument unchanged
            Ok(args
                .into_iter()
                .next()
                .unwrap_or(Expr::Literal(CypherLiteral::Null)))
        }
    }

    #[test]
    fn test_registry_new_is_empty() {
        let registry = RewriteRegistry::new();

        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.registered_functions().is_empty());
        assert!(registry.get_rule("anything").is_none());
    }

    #[test]
    fn test_registry_register_and_lookup() {
        let mut registry = RewriteRegistry::new();

        let rule = Arc::new(DummyRule::new("test.func"));
        registry.register(rule);

        assert!(!registry.is_empty());
        assert!(registry.has_rule("test.func"));
        assert!(!registry.has_rule("nonexistent"));
        assert!(registry.get_rule("nonexistent").is_none());

        let retrieved = registry
            .get_rule("test.func")
            .expect("rule registered above must be retrievable");
        assert_eq!(retrieved.function_name(), "test.func");
    }

    #[test]
    fn test_registry_replacement() {
        let mut registry = RewriteRegistry::new();

        let first = Arc::new(DummyRule::new("test.func"));
        let second = Arc::new(DummyRule::new("test.func"));

        registry.register(first.clone());
        assert_eq!(registry.len(), 1);
        // The registry holds one handle in addition to `first`.
        assert_eq!(Arc::strong_count(&first), 2);

        // Register again with same name - should replace
        registry.register(second.clone());
        assert_eq!(registry.len(), 1);
        // Replacement must drop the registry's handle to the old rule and keep the new
        // one; the counts distinguish the two otherwise-identical rules.
        assert_eq!(
            Arc::strong_count(&first),
            1,
            "registry should have released the replaced rule"
        );
        assert_eq!(
            Arc::strong_count(&second),
            2,
            "registry should hold the replacing rule"
        );
    }

    #[test]
    fn test_registry_registered_functions() {
        let mut registry = RewriteRegistry::new();

        registry.register(Arc::new(DummyRule::new("func1")));
        registry.register(Arc::new(DummyRule::new("func2")));
        registry.register(Arc::new(DummyRule::new("func3")));

        // Sort locally so the assertion does not depend on the returned order.
        let mut functions = registry.registered_functions();
        functions.sort();
        assert_eq!(functions, ["func1", "func2", "func3"]);
    }
}