Skip to main content

thread_rule_engine/rule/
referent_rule.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use crate::{Rule, RuleCore};
8
9use thread_ast_engine::meta_var::MetaVarEnv;
10use thread_ast_engine::{Doc, Matcher, Node};
11
12use bit_set::BitSet;
13use thiserror::Error;
14
15use std::borrow::Cow;
16use std::sync::{Arc, Weak};
17use thread_utilities::{RapidMap, RapidSet, set_with_capacity};
18
19#[derive(Debug)]
20pub struct Registration<R>(Arc<RapidMap<String, R>>);
21
22impl<R> Clone for Registration<R> {
23    fn clone(&self) -> Self {
24        Self(self.0.clone())
25    }
26}
27
28impl<R> Registration<R> {
29    #[allow(clippy::mut_from_ref)]
30    fn write(&self) -> &mut RapidMap<String, R> {
31        // SAFETY: `write` will only be called during initialization and
32        // it only insert new item to the RapidMap. It is safe to cast the raw ptr.
33        unsafe { &mut *(Arc::as_ptr(&self.0) as *mut RapidMap<String, R>) }
34    }
35}
36pub type GlobalRules = Registration<RuleCore>;
37
38impl GlobalRules {
39    pub fn insert(&self, id: &str, rule: RuleCore) -> Result<(), ReferentRuleError> {
40        let map = self.write();
41        if map.contains_key(id) {
42            return Err(ReferentRuleError::DuplicateRule(id.into()));
43        }
44        map.insert(id.to_string(), rule);
45        let _rule = map.get(id).unwrap();
46        Ok(())
47    }
48}
49
50impl<R> Default for Registration<R> {
51    fn default() -> Self {
52        Self(Default::default())
53    }
54}
55
56#[derive(Clone, Debug, Default)]
57pub struct RuleRegistration {
58    /// utility rule to every RuleCore, every sub-rule has its own local utility
59    local: Registration<Rule>,
60    /// global rules are shared by all RuleConfigs. It is a singleton.
61    global: Registration<RuleCore>,
62    /// Every RuleConfig has its own rewriters. But sub-rules share parent's rewriters.
63    rewriters: Registration<RuleCore>,
64}
65
// NOTE: the author considers this registration code in need of cleanup/refactoring.
67impl RuleRegistration {
68    pub fn get_rewriters(&self) -> &RapidMap<String, RuleCore> {
69        &self.rewriters.0
70    }
71
72    pub fn from_globals(global: &GlobalRules) -> Self {
73        Self {
74            local: Default::default(),
75            global: global.clone(),
76            rewriters: Default::default(),
77        }
78    }
79
80    fn get_ref(&self) -> RegistrationRef {
81        let local = Arc::downgrade(&self.local.0);
82        let global = Arc::downgrade(&self.global.0);
83        RegistrationRef { local, global }
84    }
85
86    pub(crate) fn insert_local(&self, id: &str, rule: Rule) -> Result<(), ReferentRuleError> {
87        if rule.check_cyclic(id) {
88            return Err(ReferentRuleError::CyclicRule(id.into()));
89        }
90        let map = self.local.write();
91        if map.contains_key(id) {
92            return Err(ReferentRuleError::DuplicateRule(id.into()));
93        }
94        map.insert(id.to_string(), rule);
95        Ok(())
96    }
97
98    pub(crate) fn insert_rewriter(&self, id: &str, rewriter: RuleCore) {
99        self.rewriters.insert(id, rewriter).expect("should work");
100    }
101
102    pub(crate) fn get_local_util_vars(&self) -> RapidSet<&str> {
103        let utils = &self.local.0;
104        let size = size_of_val(utils);
105        if size == 0 {
106            return RapidSet::default();
107        }
108        // this gets closer to the actual size
109        let mut ret = set_with_capacity(size);
110        for rule in utils.values() {
111            for v in rule.defined_vars() {
112                ret.insert(v);
113            }
114        }
115        ret
116    }
117}
118
119/// RegistrationRef must use Weak pointer to avoid
120/// cyclic reference in RuleRegistration
121#[derive(Clone, Debug)]
122struct RegistrationRef {
123    local: Weak<RapidMap<String, Rule>>,
124    global: Weak<RapidMap<String, RuleCore>>,
125}
126impl RegistrationRef {
127    fn get_local(&self) -> Arc<RapidMap<String, Rule>> {
128        self.local
129            .upgrade()
130            .expect("Rule Registration must be kept alive")
131    }
132    fn get_global(&self) -> Arc<RapidMap<String, RuleCore>> {
133        self.global
134            .upgrade()
135            .expect("Rule Registration must be kept alive")
136    }
137}
138
139#[derive(Error, Debug)]
140pub enum ReferentRuleError {
141    #[error("Rule `{0}` is not defined.")]
142    UndefinedUtil(String),
143    #[error("Duplicate rule id `{0}` is found.")]
144    DuplicateRule(String),
145    #[error("Rule `{0}` has a cyclic dependency in its `matches` sub-rule.")]
146    CyclicRule(String),
147}
148
149#[derive(Clone, Debug)]
150pub struct ReferentRule {
151    pub(crate) rule_id: String,
152    reg_ref: RegistrationRef,
153}
154
155impl ReferentRule {
156    pub fn try_new(
157        rule_id: String,
158        registration: &RuleRegistration,
159    ) -> Result<Self, ReferentRuleError> {
160        Ok(Self {
161            reg_ref: registration.get_ref(),
162            rule_id,
163        })
164    }
165
166    fn eval_local<F, T>(&self, func: F) -> Option<T>
167    where
168        F: FnOnce(&Rule) -> T,
169    {
170        let rules = self.reg_ref.get_local();
171        let rule = rules.get(&self.rule_id)?;
172        Some(func(rule))
173    }
174
175    fn eval_global<F, T>(&self, func: F) -> Option<T>
176    where
177        F: FnOnce(&RuleCore) -> T,
178    {
179        let rules = self.reg_ref.get_global();
180        let rule = rules.get(&self.rule_id)?;
181        Some(func(rule))
182    }
183
184    pub(super) fn verify_util(&self) -> Result<(), ReferentRuleError> {
185        let rules = self.reg_ref.get_local();
186        if rules.contains_key(&self.rule_id) {
187            return Ok(());
188        }
189        let rules = self.reg_ref.get_global();
190        if rules.contains_key(&self.rule_id) {
191            return Ok(());
192        }
193        Err(ReferentRuleError::UndefinedUtil(self.rule_id.clone()))
194    }
195}
196
197impl Matcher for ReferentRule {
198    fn match_node_with_env<'tree, D: Doc>(
199        &self,
200        node: Node<'tree, D>,
201        env: &mut Cow<MetaVarEnv<'tree, D>>,
202    ) -> Option<Node<'tree, D>> {
203        self.eval_local(|r| r.match_node_with_env(node.clone(), env))
204            .or_else(|| self.eval_global(|r| r.match_node_with_env(node, env)))
205            .flatten()
206    }
207    fn potential_kinds(&self) -> Option<BitSet> {
208        self.eval_local(|r| {
209            debug_assert!(!r.check_cyclic(&self.rule_id), "no cyclic rule allowed");
210            r.potential_kinds()
211        })
212        .or_else(|| {
213            self.eval_global(|r| {
214                debug_assert!(!r.check_cyclic(&self.rule_id), "no cyclic rule allowed");
215                r.potential_kinds()
216            })
217        })
218        .flatten()
219    }
220}
221
#[cfg(test)]
mod test {
    use super::*;
    use crate::rule::Rule;
    use crate::test::TypeScript as TS;
    use thread_ast_engine::Pattern;
    use thread_ast_engine::ops as o;

    type Result = std::result::Result<(), ReferentRuleError>;

    #[test]
    fn test_cyclic_error() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::Matches(rule);
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_cyclic_all() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::All(o::All::new(std::iter::once(Rule::Matches(rule))));
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_cyclic_not() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let rule = Rule::Not(Box::new(o::Not::new(Rule::Matches(rule))));
        let error = registration.insert_local("test", rule);
        assert!(matches!(error, Err(ReferentRuleError::CyclicRule(_))));
        Ok(())
    }

    #[test]
    fn test_success_rule() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        let pattern = Rule::Pattern(Pattern::new("some", &TS::Tsx));
        let ret = registration.insert_local("test", pattern);
        assert!(ret.is_ok());
        assert!(rule.potential_kinds().is_some());
        Ok(())
    }

    /// Registering the same local id twice must be rejected.
    #[test]
    fn test_duplicate_local() -> Result {
        let registration = RuleRegistration::default();
        let first = Rule::Pattern(Pattern::new("some", &TS::Tsx));
        registration.insert_local("test", first)?;
        let second = Rule::Pattern(Pattern::new("other", &TS::Tsx));
        let error = registration.insert_local("test", second);
        assert!(matches!(error, Err(ReferentRuleError::DuplicateRule(_))));
        Ok(())
    }

    /// A reference to an id that was never registered fails verification.
    #[test]
    fn test_undefined_util() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("missing".into(), &registration)?;
        assert!(matches!(
            rule.verify_util(),
            Err(ReferentRuleError::UndefinedUtil(id)) if id == "missing"
        ));
        Ok(())
    }

    /// A reference created before its target is registered resolves once the
    /// target exists.
    #[test]
    fn test_defined_util() -> Result {
        let registration = RuleRegistration::default();
        let rule = ReferentRule::try_new("test".into(), &registration)?;
        registration.insert_local("test", Rule::Pattern(Pattern::new("some", &TS::Tsx)))?;
        rule.verify_util()
    }

    /// With no local utility rules, no variables are reported.
    #[test]
    fn test_local_util_vars_empty() {
        let registration = RuleRegistration::default();
        assert!(registration.get_local_util_vars().is_empty());
    }
}