thread_rule_engine/rule/
referent_rule.rs1use crate::{Rule, RuleCore};
8
9use thread_ast_engine::meta_var::MetaVarEnv;
10use thread_ast_engine::{Doc, Matcher, Node};
11
12use bit_set::BitSet;
13use thiserror::Error;
14
15use std::borrow::Cow;
16use std::sync::{Arc, Weak};
17use thread_utilities::{RapidMap, RapidSet, set_with_capacity};
18
19#[derive(Debug)]
20pub struct Registration<R>(Arc<RapidMap<String, R>>);
21
22impl<R> Clone for Registration<R> {
23 fn clone(&self) -> Self {
24 Self(self.0.clone())
25 }
26}
27
28impl<R> Registration<R> {
29 #[allow(clippy::mut_from_ref)]
30 fn write(&self) -> &mut RapidMap<String, R> {
31 unsafe { &mut *(Arc::as_ptr(&self.0) as *mut RapidMap<String, R>) }
34 }
35}
36pub type GlobalRules = Registration<RuleCore>;
37
38impl GlobalRules {
39 pub fn insert(&self, id: &str, rule: RuleCore) -> Result<(), ReferentRuleError> {
40 let map = self.write();
41 if map.contains_key(id) {
42 return Err(ReferentRuleError::DuplicateRule(id.into()));
43 }
44 map.insert(id.to_string(), rule);
45 let _rule = map.get(id).unwrap();
46 Ok(())
47 }
48}
49
50impl<R> Default for Registration<R> {
51 fn default() -> Self {
52 Self(Default::default())
53 }
54}
55
56#[derive(Clone, Debug, Default)]
57pub struct RuleRegistration {
58 local: Registration<Rule>,
60 global: Registration<RuleCore>,
62 rewriters: Registration<RuleCore>,
64}
65
66impl RuleRegistration {
68 pub fn get_rewriters(&self) -> &RapidMap<String, RuleCore> {
69 &self.rewriters.0
70 }
71
72 pub fn from_globals(global: &GlobalRules) -> Self {
73 Self {
74 local: Default::default(),
75 global: global.clone(),
76 rewriters: Default::default(),
77 }
78 }
79
80 fn get_ref(&self) -> RegistrationRef {
81 let local = Arc::downgrade(&self.local.0);
82 let global = Arc::downgrade(&self.global.0);
83 RegistrationRef { local, global }
84 }
85
86 pub(crate) fn insert_local(&self, id: &str, rule: Rule) -> Result<(), ReferentRuleError> {
87 if rule.check_cyclic(id) {
88 return Err(ReferentRuleError::CyclicRule(id.into()));
89 }
90 let map = self.local.write();
91 if map.contains_key(id) {
92 return Err(ReferentRuleError::DuplicateRule(id.into()));
93 }
94 map.insert(id.to_string(), rule);
95 Ok(())
96 }
97
98 pub(crate) fn insert_rewriter(&self, id: &str, rewriter: RuleCore) {
99 self.rewriters.insert(id, rewriter).expect("should work");
100 }
101
102 pub(crate) fn get_local_util_vars(&self) -> RapidSet<&str> {
103 let utils = &self.local.0;
104 let size = size_of_val(utils);
105 if size == 0 {
106 return RapidSet::default();
107 }
108 let mut ret = set_with_capacity(size);
110 for rule in utils.values() {
111 for v in rule.defined_vars() {
112 ret.insert(v);
113 }
114 }
115 ret
116 }
117}
118
119#[derive(Clone, Debug)]
122struct RegistrationRef {
123 local: Weak<RapidMap<String, Rule>>,
124 global: Weak<RapidMap<String, RuleCore>>,
125}
126impl RegistrationRef {
127 fn get_local(&self) -> Arc<RapidMap<String, Rule>> {
128 self.local
129 .upgrade()
130 .expect("Rule Registration must be kept alive")
131 }
132 fn get_global(&self) -> Arc<RapidMap<String, RuleCore>> {
133 self.global
134 .upgrade()
135 .expect("Rule Registration must be kept alive")
136 }
137}
138
139#[derive(Error, Debug)]
140pub enum ReferentRuleError {
141 #[error("Rule `{0}` is not defined.")]
142 UndefinedUtil(String),
143 #[error("Duplicate rule id `{0}` is found.")]
144 DuplicateRule(String),
145 #[error("Rule `{0}` has a cyclic dependency in its `matches` sub-rule.")]
146 CyclicRule(String),
147}
148
149#[derive(Clone, Debug)]
150pub struct ReferentRule {
151 pub(crate) rule_id: String,
152 reg_ref: RegistrationRef,
153}
154
155impl ReferentRule {
156 pub fn try_new(
157 rule_id: String,
158 registration: &RuleRegistration,
159 ) -> Result<Self, ReferentRuleError> {
160 Ok(Self {
161 reg_ref: registration.get_ref(),
162 rule_id,
163 })
164 }
165
166 fn eval_local<F, T>(&self, func: F) -> Option<T>
167 where
168 F: FnOnce(&Rule) -> T,
169 {
170 let rules = self.reg_ref.get_local();
171 let rule = rules.get(&self.rule_id)?;
172 Some(func(rule))
173 }
174
175 fn eval_global<F, T>(&self, func: F) -> Option<T>
176 where
177 F: FnOnce(&RuleCore) -> T,
178 {
179 let rules = self.reg_ref.get_global();
180 let rule = rules.get(&self.rule_id)?;
181 Some(func(rule))
182 }
183
184 pub(super) fn verify_util(&self) -> Result<(), ReferentRuleError> {
185 let rules = self.reg_ref.get_local();
186 if rules.contains_key(&self.rule_id) {
187 return Ok(());
188 }
189 let rules = self.reg_ref.get_global();
190 if rules.contains_key(&self.rule_id) {
191 return Ok(());
192 }
193 Err(ReferentRuleError::UndefinedUtil(self.rule_id.clone()))
194 }
195}
196
197impl Matcher for ReferentRule {
198 fn match_node_with_env<'tree, D: Doc>(
199 &self,
200 node: Node<'tree, D>,
201 env: &mut Cow<MetaVarEnv<'tree, D>>,
202 ) -> Option<Node<'tree, D>> {
203 self.eval_local(|r| r.match_node_with_env(node.clone(), env))
204 .or_else(|| self.eval_global(|r| r.match_node_with_env(node, env)))
205 .flatten()
206 }
207 fn potential_kinds(&self) -> Option<BitSet> {
208 self.eval_local(|r| {
209 debug_assert!(!r.check_cyclic(&self.rule_id), "no cyclic rule allowed");
210 r.potential_kinds()
211 })
212 .or_else(|| {
213 self.eval_global(|r| {
214 debug_assert!(!r.check_cyclic(&self.rule_id), "no cyclic rule allowed");
215 r.potential_kinds()
216 })
217 })
218 .flatten()
219 }
220}
221
#[cfg(test)]
mod test {
    use super::*;
    use crate::rule::Rule;
    use crate::test::TypeScript as TS;
    use thread_ast_engine::Pattern;
    use thread_ast_engine::ops as o;

    type Result = std::result::Result<(), ReferentRuleError>;

    // NOTE: every `&registration` below had been mis-encoded as `®istration`
    // (the HTML entity `&reg;`), which kept this module from compiling.

    #[test]
    fn test_cyclic_error() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::Matches(rule);
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_cyclic_all() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::All(o::All::new(std::iter::once(Rule::Matches(rule))));
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_cyclic_not() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::Not(Box::new(o::Not::new(Rule::Matches(rule))));
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_success_rule() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let pattern = Rule::Pattern(Pattern::new("some", &TS::Tsx));
        let ret = registration.insert_local("test", pattern);
        assert!(ret.is_ok());
        assert!(rule.potential_kinds().is_some());
        Ok(())
    }

    #[test]
    fn test_duplicate_local_rule() -> Result {
        let registration = RuleRegistration::default();
        let first = Rule::Pattern(Pattern::new("a", &TS::Tsx));
        registration.insert_local("test", first)?;
        let second = Rule::Pattern(Pattern::new("b", &TS::Tsx));
        let error = registration.insert_local("test", second);
        assert!(matches!(error, Err(ReferentRuleError::DuplicateRule(_))));
        Ok(())
    }

    #[test]
    fn test_undefined_util() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("missing".into(), &registration)?;
        assert!(matches!(
            rule.verify_util(),
            Err(ReferentRuleError::UndefinedUtil(_))
        ));
        assert!(rule.potential_kinds().is_none());
        assert!(registration.get_local_util_vars().is_empty());
        Ok(())
    }
}