Skip to main content

thread_rule_engine/
label.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::ops::Range;
10use thread_ast_engine::{Doc, Node, NodeMatch};
11use thread_utilities::RapidMap;
12
13#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema, PartialEq, Eq)]
14#[serde(rename_all = "camelCase")]
15pub enum LabelStyle {
16    /// Labels that describe the primary cause of a diagnostic.
17    Primary,
18    /// Labels that provide additional context for a diagnostic.
19    Secondary,
20}
21
22#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
23pub struct LabelConfig {
24    pub style: LabelStyle,
25    pub message: Option<String>,
26}
27
28/// A label is a way to mark a specific part of the code with a styled message.
29/// It is used to provide diagnostic information in LSP or CLI.
30/// 'r represents a lifetime for the message string from `rule`.
31/// 't represents a lifetime for the node from a ast `tree`.
32#[derive(Clone)]
33pub struct Label<'r, 't, D: Doc> {
34    pub style: LabelStyle,
35    pub message: Option<&'r str>,
36    pub start_node: Node<'t, D>,
37    pub end_node: Node<'t, D>,
38}
39
40impl<'t, D: Doc> Label<'_, 't, D> {
41    fn primary(n: &Node<'t, D>) -> Self {
42        Self {
43            style: LabelStyle::Primary,
44            start_node: n.clone(),
45            end_node: n.clone(),
46            message: None,
47        }
48    }
49    fn secondary(n: &Node<'t, D>) -> Self {
50        Self {
51            style: LabelStyle::Secondary,
52            start_node: n.clone(),
53            end_node: n.clone(),
54            message: None,
55        }
56    }
57
58    pub fn range(&self) -> Range<usize> {
59        let start = self.start_node.range().start;
60        let end = self.end_node.range().end;
61        start..end
62    }
63}
64
65pub fn get_labels_from_config<'r, 't, D: Doc>(
66    config: &'r RapidMap<String, LabelConfig>,
67    node_match: &NodeMatch<'t, D>,
68) -> Vec<Label<'r, 't, D>> {
69    let env = node_match.get_env();
70    config
71        .iter()
72        .filter_map(|(var, conf)| {
73            let (start, end) = if let Some(n) = env.get_match(var) {
74                (n.clone(), n.clone())
75            } else {
76                let ns = env.get_multiple_matches(var);
77                let start = ns.first()?.clone();
78                let end = ns.last()?.clone();
79                (start, end)
80            };
81            Some(Label {
82                style: conf.style.clone(),
83                message: conf.message.as_deref(),
84                start_node: start,
85                end_node: end,
86            })
87        })
88        .collect()
89}
90
91pub fn get_default_labels<'t, D: Doc>(n: &NodeMatch<'t, D>) -> Vec<Label<'static, 't, D>> {
92    let mut ret = vec![Label::primary(n)];
93    if let Some(secondary) = n.get_env().get_labels("secondary") {
94        ret.extend(secondary.iter().map(Label::secondary));
95    }
96    ret
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::TypeScript;
    use thread_ast_engine::matcher::Pattern;
    use thread_ast_engine::tree_sitter::LanguageExt;
    use thread_ast_engine::tree_sitter::StrDoc;

    /// Both single-node constructors set the right style, carry no message
    /// and span exactly the node they were built from.
    #[test]
    fn test_label_primary_secondary() {
        let doc = TypeScript::Tsx.ast_grep("let a = 1;");
        let root = doc.root();
        let label = Label::primary(&root);
        assert_eq!(label.style, LabelStyle::Primary);
        assert_eq!(label.message, None);
        assert_eq!(label.range(), root.range());
        let label2 = Label::<'_, '_, StrDoc<TypeScript>>::secondary(&root);
        assert_eq!(label2.style, LabelStyle::Secondary);
        assert_eq!(label2.message, None);
        assert_eq!(label2.range(), root.range());
    }

    /// A configured meta-variable bound to one node yields one label that
    /// carries the configured style and message and spans that node.
    #[test]
    fn test_get_labels_from_config_single() {
        let src = "let foo = 42;";
        let doc = TypeScript::Tsx.ast_grep(src);
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "A".to_string(),
            LabelConfig {
                style: LabelStyle::Primary,
                message: Some("var label".to_string()),
            },
        );
        let labels = get_labels_from_config(&config, &m);
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].style, LabelStyle::Primary);
        assert_eq!(labels[0].message, Some("var label"));
        // The label must point at the node bound to `$A`.
        assert_eq!(&src[labels[0].range()], "foo");
    }

    /// Several configured meta-variables each produce their own label.
    #[test]
    fn test_get_labels_from_config_multiple() {
        let src = "let foo = 42, bar = 99;";
        let doc = TypeScript::Tsx.ast_grep(src);
        let pattern = Pattern::try_new("let $A = $B, $C = $D;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        for var in ["A", "C"] {
            config.insert(
                var.to_string(),
                LabelConfig {
                    style: LabelStyle::Secondary,
                    message: None,
                },
            );
        }
        let labels = get_labels_from_config(&config, &m);
        assert_eq!(labels.len(), 2);
        assert!(
            labels
                .iter()
                .all(|l| l.style == LabelStyle::Secondary && l.message.is_none())
        );
        // Map iteration order is unspecified, so compare the spans sorted.
        let mut texts: Vec<&str> = labels.iter().map(|l| &src[l.range()]).collect();
        texts.sort_unstable();
        assert_eq!(texts, ["bar", "foo"]);
    }

    /// A configured name that is not bound in the match produces no label.
    #[test]
    fn test_get_labels_from_config_unmatched_var() {
        let doc = TypeScript::Tsx.ast_grep("let foo = 42;");
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "MISSING".to_string(),
            LabelConfig {
                style: LabelStyle::Primary,
                message: None,
            },
        );
        let labels = get_labels_from_config(&config, &m);
        assert!(labels.is_empty());
    }

    /// The default labels start with a message-less primary label that
    /// covers the matched node.
    #[test]
    fn test_get_default_labels() {
        let doc = TypeScript::Tsx.ast_grep("let foo = 42;");
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let labels = get_default_labels(&m);
        assert!(!labels.is_empty());
        assert_eq!(labels[0].style, LabelStyle::Primary);
        assert_eq!(labels[0].message, None);
        assert_eq!(labels[0].range(), m.range());
    }
}