thread_rule_engine/
label.rs1use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::ops::Range;
10use thread_ast_engine::{Doc, Node, NodeMatch};
11use thread_utilities::RapidMap;
12
13#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema, PartialEq, Eq)]
14#[serde(rename_all = "camelCase")]
15pub enum LabelStyle {
16 Primary,
18 Secondary,
20}
21
22#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
23pub struct LabelConfig {
24 pub style: LabelStyle,
25 pub message: Option<String>,
26}
27
28#[derive(Clone)]
33pub struct Label<'r, 't, D: Doc> {
34 pub style: LabelStyle,
35 pub message: Option<&'r str>,
36 pub start_node: Node<'t, D>,
37 pub end_node: Node<'t, D>,
38}
39
40impl<'t, D: Doc> Label<'_, 't, D> {
41 fn primary(n: &Node<'t, D>) -> Self {
42 Self {
43 style: LabelStyle::Primary,
44 start_node: n.clone(),
45 end_node: n.clone(),
46 message: None,
47 }
48 }
49 fn secondary(n: &Node<'t, D>) -> Self {
50 Self {
51 style: LabelStyle::Secondary,
52 start_node: n.clone(),
53 end_node: n.clone(),
54 message: None,
55 }
56 }
57
58 pub fn range(&self) -> Range<usize> {
59 let start = self.start_node.range().start;
60 let end = self.end_node.range().end;
61 start..end
62 }
63}
64
65pub fn get_labels_from_config<'r, 't, D: Doc>(
66 config: &'r RapidMap<String, LabelConfig>,
67 node_match: &NodeMatch<'t, D>,
68) -> Vec<Label<'r, 't, D>> {
69 let env = node_match.get_env();
70 config
71 .iter()
72 .filter_map(|(var, conf)| {
73 let (start, end) = if let Some(n) = env.get_match(var) {
74 (n.clone(), n.clone())
75 } else {
76 let ns = env.get_multiple_matches(var);
77 let start = ns.first()?.clone();
78 let end = ns.last()?.clone();
79 (start, end)
80 };
81 Some(Label {
82 style: conf.style.clone(),
83 message: conf.message.as_deref(),
84 start_node: start,
85 end_node: end,
86 })
87 })
88 .collect()
89}
90
91pub fn get_default_labels<'t, D: Doc>(n: &NodeMatch<'t, D>) -> Vec<Label<'static, 't, D>> {
92 let mut ret = vec![Label::primary(n)];
93 if let Some(secondary) = n.get_env().get_labels("secondary") {
94 ret.extend(secondary.iter().map(Label::secondary));
95 }
96 ret
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::TypeScript;
    use thread_ast_engine::matcher::Pattern;
    use thread_ast_engine::tree_sitter::LanguageExt;
    use thread_ast_engine::tree_sitter::StrDoc;

    #[test]
    fn test_label_primary_secondary() {
        let doc = TypeScript::Tsx.ast_grep("let a = 1;");
        let root = doc.root();
        let label = Label::primary(&root);
        assert_eq!(label.style, LabelStyle::Primary);
        assert!(label.message.is_none());
        assert_eq!(label.range(), root.range());
        let label2 = Label::<'_, '_, StrDoc<TypeScript>>::secondary(&root);
        assert_eq!(label2.style, LabelStyle::Secondary);
        assert!(label2.message.is_none());
        assert_eq!(label2.range(), root.range());
    }

    #[test]
    fn test_get_labels_from_config_single() {
        let src = "let foo = 42;";
        let doc = TypeScript::Tsx.ast_grep(src);
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "A".to_string(),
            LabelConfig {
                style: LabelStyle::Primary,
                message: Some("var label".to_string()),
            },
        );
        let labels = get_labels_from_config(&config, &m);
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].style, LabelStyle::Primary);
        assert_eq!(labels[0].message, Some("var label"));
        // A single-node capture spans exactly that node.
        assert_eq!(&src[labels[0].range()], "foo");
    }

    #[test]
    fn test_get_labels_from_config_multiple() {
        let doc = TypeScript::Tsx.ast_grep("let foo = 42, bar = 99;");
        let pattern = Pattern::try_new("let $A = $B, $C = $D;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "A".to_string(),
            LabelConfig {
                style: LabelStyle::Secondary,
                message: None,
            },
        );
        let labels = get_labels_from_config(&config, &m);
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].style, LabelStyle::Secondary);
    }

    #[test]
    fn test_get_labels_from_config_multi_match_spans_first_to_last() {
        // `$$$ARGS` is a multi-node capture, so this exercises the
        // `get_multiple_matches` branch rather than `get_match`.
        let src = "foo(a, b, c);";
        let doc = TypeScript::Tsx.ast_grep(src);
        let pattern = Pattern::try_new("foo($$$ARGS)", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "ARGS".to_string(),
            LabelConfig {
                style: LabelStyle::Secondary,
                message: Some("args".to_string()),
            },
        );
        let labels = get_labels_from_config(&config, &m);
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].style, LabelStyle::Secondary);
        assert_eq!(labels[0].message, Some("args"));
        assert_eq!(&src[labels[0].range()], "a, b, c");
    }

    #[test]
    fn test_get_labels_from_config_unknown_var_is_skipped() {
        let doc = TypeScript::Tsx.ast_grep("let foo = 42;");
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let mut config = thread_utilities::RapidMap::default();
        config.insert(
            "MISSING".to_string(),
            LabelConfig {
                style: LabelStyle::Primary,
                message: None,
            },
        );
        // A variable with no capture produces no label instead of an error.
        assert!(get_labels_from_config(&config, &m).is_empty());
    }

    #[test]
    fn test_get_default_labels() {
        let doc = TypeScript::Tsx.ast_grep("let foo = 42;");
        let pattern = Pattern::try_new("let $A = $B;", &TypeScript::Tsx).unwrap();
        let m = doc.root().find(pattern).unwrap();
        let labels = get_default_labels(&m);
        assert!(!labels.is_empty());
        assert_eq!(labels[0].style, LabelStyle::Primary);
        // The primary default label covers the whole match.
        assert_eq!(labels[0].range(), m.range());
    }
}