zee_highlight/
lib.rs

1mod error;
2mod selector;
3
4use fnv::FnvHashMap;
5use serde_derive::{self, Deserialize, Serialize};
6use std::{cmp, collections::HashMap, convert::TryFrom};
7
8use error::Result;
9use tree_sitter::Language;
10
11use crate::selector::{map_node_kind_names, Selector};
12
13pub use crate::selector::SelectorNodeId;
14
/// Compiled highlighting rules for one language.
///
/// Produced from [`RawHighlightRules`] by `RawHighlightRules::compile`, which resolves
/// the node kind names used in selectors to [`SelectorNodeId`]s.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HighlightRules {
    /// Name of the rule set, copied from the raw rules.
    name: String,
    /// Maps tree-sitter node kind ids to selector node ids; consulted by
    /// `get_selector_node_id`. When built by `build_node_to_selector_id_maps`, node
    /// kinds that share a name share a selector id.
    node_id_to_selector_id: FnvHashMap<u16, SelectorNodeId>,

    /// Rules scanned by `matches`; defaults to empty when absent from serialized data.
    #[serde(default)]
    rules: Vec<HighlightRule>,
}
23
/// A single compiled rule: a set of alternative selectors and the scope pattern that
/// applies when any one of them matches.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HighlightRule {
    /// Selectors parsed from one selector string; a match of any of them is enough.
    selectors: Vec<Selector>,
    /// Decides the resulting scope, possibly depending on the node's text.
    scope: ScopePattern,
}
29
30impl HighlightRules {
31    #[inline]
32    pub fn get_selector_node_id(&self, node_kind_id: u16) -> SelectorNodeId {
33        self.node_id_to_selector_id
34            .get(&node_kind_id)
35            .copied()
36            .unwrap_or_else(|| {
37                SelectorNodeId(u16::try_from(self.node_id_to_selector_id.len()).unwrap())
38            })
39    }
40
41    #[inline]
42    pub fn matches(
43        &self,
44        node_stack: &[SelectorNodeId],
45        nth_children: &[u16],
46        content: &str,
47    ) -> Option<&Scope> {
48        if node_stack.is_empty() {
49            return None;
50        }
51
52        let mut distance_to_match = std::usize::MAX;
53        let mut num_nodes_match = 0;
54        let mut scope_pattern = None;
55        for rule in self.rules.iter() {
56            let rule_scope = match rule.scope.matches(content) {
57                Some(scope) => scope,
58                None => continue,
59            };
60
61            for selector in rule.selectors.iter() {
62                let selector_node_kinds = selector.node_kinds();
63                let selector_nth_children = selector.nth_children();
64
65                // eprintln!("NST {:?} {:?}", node_stack, nth_children);
66                // eprintln!("SEL {:?} {:?}", selector_node_kinds, selector_nth_children);
67
68                assert!(!selector_node_kinds.is_empty());
69                if selector_node_kinds.len() > node_stack.len() {
70                    continue;
71                }
72
73                // TODO: Are for loops over inclusive ranges slow?
74                for start in 0..=cmp::min(
75                    node_stack.len().saturating_sub(selector_node_kinds.len()),
76                    distance_to_match,
77                ) {
78                    let span_range = || start..start + selector_node_kinds.len();
79
80                    // Does the selector match the current node and its ancestors?
81                    if selector_node_kinds
82                        != &node_stack[start..(start + selector_node_kinds.len())]
83                    {
84                        continue;
85                    }
86
87                    // Are the `nth-child` constrains also satisfied?
88                    let nth_child_not_satisfied = selector_nth_children
89                        .iter()
90                        .zip(nth_children[span_range()].iter())
91                        .any(|(&nth_child_selector, &node_sibling_index)| {
92                            nth_child_selector >= 0
93                                && nth_child_selector as u16 != node_sibling_index
94                        });
95                    if nth_child_not_satisfied {
96                        continue;
97                    }
98
99                    // Is the selector more specific than the most specific
100                    // match we've found so far?
101                    if start == distance_to_match && num_nodes_match > selector_node_kinds.len() {
102                        break;
103                    }
104
105                    assert!(start <= distance_to_match);
106                    // eprintln!(
107                    //     "!!D {} -> {} | N {} -> {}",
108                    //     distance_to_match,
109                    //     start,
110                    //     num_nodes_match,
111                    //     selector_node_kinds.len()
112                    // );
113
114                    distance_to_match = start;
115                    num_nodes_match = selector_node_kinds.len();
116                    scope_pattern = Some(rule_scope);
117                    break;
118                }
119            }
120        }
121
122        scope_pattern
123    }
124}
125
/// Highlighting rules as deserialized from a rules file, before selector compilation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RawHighlightRules {
    /// Name of the rule set (the tests use the language name, e.g. `"Rust"`).
    name: String,

    /// Maps selector strings to the scope pattern they apply; an absent `scopes` key
    /// deserializes to an empty map.
    #[serde(default)]
    pub scopes: HashMap<String, ScopePattern>,
}
133
134impl RawHighlightRules {
135    fn compile(self, language: Language) -> Result<HighlightRules> {
136        let (node_name_to_selector_id, node_id_to_selector_id) =
137            build_node_to_selector_id_maps(language);
138        let RawHighlightRules { name, scopes } = self;
139
140        scopes
141            .into_iter()
142            .map(|(selector_str, scope)| {
143                let selectors = selector::parse(&selector_str)?;
144                let selectors = selectors
145                    .into_iter()
146                    .map(|selector| map_node_kind_names(&node_name_to_selector_id, selector))
147                    .collect::<Result<Vec<_>>>()?;
148                Ok(HighlightRule { selectors, scope })
149            })
150            .collect::<Result<Vec<_>>>()
151            .map(|rules| HighlightRules {
152                name,
153                rules,
154                node_id_to_selector_id,
155            })
156    }
157}
158
159fn build_node_to_selector_id_maps(
160    language: Language,
161) -> (
162    FnvHashMap<&'static str, SelectorNodeId>,
163    FnvHashMap<u16, SelectorNodeId>,
164) {
165    let mut node_name_to_selector_id =
166        FnvHashMap::with_capacity_and_hasher(language.node_kind_count(), Default::default());
167    let mut node_id_to_selector_id =
168        FnvHashMap::with_capacity_and_hasher(language.node_kind_count(), Default::default());
169
170    let node_id_range =
171        0..u16::try_from(language.node_kind_count()).expect("node_kind_count() should fit in u16");
172    for node_id in node_id_range {
173        let node_name = language
174            .node_kind_for_id(node_id)
175            .expect("node kind available for node_id in range");
176        let next_selector_id =
177            SelectorNodeId(u16::try_from(node_name_to_selector_id.len()).unwrap());
178        let selector_id = node_name_to_selector_id
179            .entry(node_name)
180            .or_insert_with(|| next_selector_id);
181        node_id_to_selector_id.insert(node_id, *selector_id);
182    }
183
184    // log::debug!(
185    //     "NKC: {}, name->sid: {}, nid->sid: {}",
186    //     language.node_kind_count(),
187    //     node_name_to_selector_id.len(),
188    //     node_id_to_selector_id.len(),
189    // );
190
191    (node_name_to_selector_id, node_id_to_selector_id)
192}
193
/// Decides which [`Scope`] a matched node gets, optionally depending on its text.
///
/// Deserialized as `#[serde(untagged)]`: serde tries the variants in declaration order
/// and keeps the first one that fits the input. The order of the variants below is
/// therefore significant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ScopePattern {
    /// Always applies. Written as a plain string, e.g. `"support.type"`.
    All(Scope),
    /// Applies only when the node text equals `exact`,
    /// e.g. `{"exact": "let", "scopes": "keyword.control"}`.
    Exact {
        exact: String,
        scopes: Scope,
    },
    /// Applies only when the regex (under the `match` key) matches the node text.
    /// This uses `is_match`, so the match may occur anywhere unless the pattern is
    /// anchored.
    Regex {
        #[serde(rename = "match")]
        regex: Regex,
        scopes: Scope,
    },
    /// Tries each pattern in order; the first one that applies wins.
    Vec(Vec<ScopePattern>),
}
209
/// Wrapper around [`regex::Regex`] that is (de)serialized with `serde_regex` and
/// compared by pattern string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Regex(#[serde(with = "serde_regex")] regex::Regex);
212
213impl Regex {
214    fn is_match(&self, text: &str) -> bool {
215        self.0.is_match(text)
216    }
217}
218
219impl PartialEq for Regex {
220    fn eq(&self, other: &Self) -> bool {
221        self.0.as_str() == other.0.as_str()
222    }
223}
224
225impl ScopePattern {
226    fn matches(&self, content: &str) -> Option<&Scope> {
227        match self {
228            ScopePattern::All(ref scopes) => Some(scopes),
229            ScopePattern::Exact {
230                ref exact,
231                ref scopes,
232            } if exact.as_str() == content => Some(scopes),
233            ScopePattern::Regex {
234                ref regex,
235                ref scopes,
236            } if regex.is_match(content) => Some(scopes),
237            ScopePattern::Vec(ref scope_patterns) => {
238                for scope_pattern in scope_patterns.iter() {
239                    let maybe_scope = scope_pattern.matches(content);
240                    if maybe_scope.is_some() {
241                        return maybe_scope;
242                    }
243                }
244                None
245            }
246            _ => None,
247        }
248    }
249}
250
/// A highlighting scope name, such as `"keyword.control"` or `"support.type"`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Scope(pub String);
253
254pub fn parse_rules_unwrap(language: Language, source: &str) -> HighlightRules {
255    let raw_rules =
256        serde_json::from_str::<RawHighlightRules>(source).expect("valid json file for rules");
257    let name = format!("valid rules for {}", raw_rules.name);
258    raw_rules.compile(language).expect(&name)
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use maplit::hashmap;

    #[test]
    fn deserialize_no_scopes() {
        // `scopes` is `#[serde(default)]`: omitting it must yield an empty map, not an error.
        let style_str = r#"{"name": "Rust"}"#;
        let expected = RawHighlightRules {
            name: "Rust".into(),
            scopes: Default::default(),
        };
        let actual: RawHighlightRules = serde_json::from_str(style_str).expect("valid json");
        assert_eq!(expected.name, actual.name);
        assert_eq!(expected.scopes, actual.scopes);
    }

    #[test]
    fn deserialize_all_scope_types() {
        // One entry per `ScopePattern` variant: plain string, `exact`, `match` (regex),
        // and a list of patterns.
        let style_str = r#"{
            "name": "Rust",
            "scopes": {
                "type_identifier": "support.type",
                "\"let\"": {"exact": "let", "scopes": "keyword.control" },
                "identifier": {"match": "^[A-Z]", "scopes": "entity.name.type" },
                "field_identifier": [
                    {"exact": "self", "scopes": "variable.language" },
                    "variable.other"
                ]
            }
        }"#;
        let expected = RawHighlightRules {
            name: "Rust".into(),
            scopes: hashmap! {
                "type_identifier".into() => ScopePattern::All(Scope("support.type".into())),
                "\"let\"".into() => ScopePattern::Exact {
                    exact: "let".into(),
                    scopes: Scope("keyword.control".into())
                },
                "identifier".into() => ScopePattern::Regex {
                    regex: Regex(regex::Regex::new("^[A-Z]").expect("valid regex")),
                    scopes: Scope("entity.name.type".into()),
                },
                "field_identifier".into() => ScopePattern::Vec(vec![
                    ScopePattern::Exact {
                        exact: "self".into(),
                        scopes: Scope("variable.language".into()),
                    },
                    ScopePattern::All(Scope("variable.other".into())),
                ]),
            },
        };
        let actual: RawHighlightRules = serde_json::from_str(style_str).expect("valid json");
        assert_eq!(expected.name, actual.name);
        assert_eq!(expected.scopes, actual.scopes);
    }
}
300}