1use anyhow::Result;
2use std::collections::HashMap;
3use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
4
5use crate::lang::Lang;
6
/// Structural complexity counters collected for one piece of source text.
///
/// All fields default to zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct ComplexityMetrics {
    /// Number of branching constructs counted.
    pub branches: i64,
    /// Number of loop constructs counted.
    pub loops: i64,
    /// Number of explicit return constructs counted.
    pub returns: i64,
    /// Deepest nesting level observed (a depth, not a count).
    pub max_nesting: i64,
    /// Number of `unsafe` blocks counted.
    pub unsafe_blocks: i64,
}

impl ComplexityMetrics {
    /// Returns `branches + loops + returns`.
    ///
    /// `max_nesting` and `unsafe_blocks` are not part of the total. Because
    /// the fields are public and may hold arbitrary values, the sum is
    /// computed in a wider type and saturates at `i64::MIN` / `i64::MAX`
    /// instead of overflowing.
    pub fn total_complexity(&self) -> i64 {
        // i128 holds the sum of any three i64 values exactly.
        let exact =
            i128::from(self.branches) + i128::from(self.loops) + i128::from(self.returns);
        i64::try_from(exact).unwrap_or(if exact.is_negative() {
            i64::MIN
        } else {
            i64::MAX
        })
    }

    /// Builds a value from the five counters, in field-declaration order.
    pub fn from_raw_fields(
        branches: i64,
        loops: i64,
        returns: i64,
        max_nesting: i64,
        unsafe_blocks: i64,
    ) -> Self {
        Self {
            branches,
            loops,
            returns,
            max_nesting,
            unsafe_blocks,
        }
    }
}
37
38pub trait LanguageExtractor: Send + Sync {
39 fn lang(&self) -> Lang;
40 fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics>;
41}
42
43struct BuiltinExtractor {
44 lang: Lang,
45}
46
47impl BuiltinExtractor {
48 fn complexity_query(&self) -> Option<&'static str> {
49 match self.lang {
50 #[cfg(feature = "lang-rust")]
51 Lang::Rust => Some(
52 r#"
53 (if_expression) @branch
54 (match_expression) @branch
55 (for_expression) @loop
56 (while_expression) @loop
57 (loop_expression) @loop
58 (return_expression) @return
59 (unsafe_block) @unsafe
60 "#,
61 ),
62 #[cfg(feature = "lang-python")]
63 Lang::Python => Some(
64 r#"
65 (if_statement) @branch
66 (elif_clause) @branch
67 (for_statement) @loop
68 (while_statement) @loop
69 (return_statement) @return
70 "#,
71 ),
72 #[cfg(feature = "lang-typescript")]
73 Lang::TypeScript | Lang::Tsx => Some(
74 r#"
75 (if_statement) @branch
76 (switch_statement) @branch
77 (ternary_expression) @branch
78 (for_statement) @loop
79 (for_in_statement) @loop
80 (while_statement) @loop
81 (do_statement) @loop
82 (return_statement) @return
83 "#,
84 ),
85 #[cfg(feature = "lang-javascript")]
86 Lang::JavaScript | Lang::Jsx => Some(
87 r#"
88 (if_statement) @branch
89 (switch_statement) @branch
90 (ternary_expression) @branch
91 (for_statement) @loop
92 (for_in_statement) @loop
93 (while_statement) @loop
94 (do_statement) @loop
95 (return_statement) @return
96 "#,
97 ),
98 #[cfg(feature = "lang-kotlin")]
99 Lang::Kotlin => Some(
100 r#"
101 (if_expression) @branch
102 (when_expression) @branch
103 (for_statement) @loop
104 (while_statement) @loop
105 (do_while_statement) @loop
106 (return_expression) @return
107 "#,
108 ),
109 _ => None,
110 }
111 }
112
113 fn compute_max_nesting(&self, source: &[u8]) -> i64 {
114 let ts_lang = self.lang.tree_sitter_language();
115 let mut parser = Parser::new();
116 if parser.set_language(&ts_lang).is_err() {
117 return 0;
118 }
119 let tree = match parser.parse(source, None) {
120 Some(t) => t,
121 None => return 0,
122 };
123 let mut max_depth: i64 = 0;
124 fn walk(node: tree_sitter::Node, depth: i64, max_depth: &mut i64) {
125 let kind = node.kind();
126 let is_scope = matches!(
127 kind,
128 "function_item"
129 | "function_definition"
130 | "function_declaration"
131 | "class_definition"
132 | "class_declaration"
133 | "impl_item"
134 | "if_expression"
135 | "if_statement"
136 | "for_expression"
137 | "for_statement"
138 | "while_expression"
139 | "while_statement"
140 | "loop_expression"
141 | "match_expression"
142 | "switch_statement"
143 | "when_expression"
144 | "block"
145 | "expression_list"
146 );
147 let child_depth = if is_scope { depth + 1 } else { depth };
148 if child_depth > *max_depth {
149 *max_depth = child_depth;
150 }
151 let mut cursor = node.walk();
152 for child in node.children(&mut cursor) {
153 walk(child, child_depth, max_depth);
154 }
155 }
156 walk(tree.root_node(), 0, &mut max_depth);
157 max_depth.max(0)
158 }
159}
160
161impl LanguageExtractor for BuiltinExtractor {
162 fn lang(&self) -> Lang {
163 self.lang
164 }
165
166 fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics> {
167 let query_str = match self.complexity_query() {
168 Some(q) => q,
169 None => return Ok(ComplexityMetrics::default()),
170 };
171 let ts_lang = self.lang.tree_sitter_language();
172 let mut parser = Parser::new();
173 parser.set_language(&ts_lang)?;
174 let tree = parser
175 .parse(source, None)
176 .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
177 let query = Query::new(&ts_lang, query_str)?;
178 let mut cursor = QueryCursor::new();
179 let mut metrics = ComplexityMetrics::default();
180
181 let capture_names: Vec<String> = query
182 .capture_names()
183 .iter()
184 .map(|s| s.to_string())
185 .collect();
186
187 let mut matches = cursor.matches(&query, tree.root_node(), source);
188 while let Some(m) = matches.next() {
189 for capture in m.captures {
190 let name = &capture_names[capture.index as usize];
191 match name.as_str() {
192 "branch" => metrics.branches += 1,
193 "loop" => metrics.loops += 1,
194 "return" => metrics.returns += 1,
195 "unsafe" => metrics.unsafe_blocks += 1,
196 _ => {}
197 }
198 }
199 }
200
201 metrics.max_nesting = self.compute_max_nesting(source);
202 Ok(metrics)
203 }
204}
205
206pub struct LanguageRegistry {
207 extractors: HashMap<String, Box<dyn LanguageExtractor>>,
208}
209
210impl LanguageRegistry {
211 pub fn new() -> Self {
212 let mut registry = Self {
213 extractors: HashMap::new(),
214 };
215 registry.register_builtins();
216 registry
217 }
218
219 fn register_builtins(&mut self) {
220 for lang in Lang::all() {
221 let ext = lang.name().to_string();
222 let extractor = BuiltinExtractor { lang };
223 self.extractors.insert(ext, Box::new(extractor));
224 }
225 }
226
227 pub fn register(&mut self, name: String, extractor: Box<dyn LanguageExtractor>) {
228 self.extractors.insert(name, extractor);
229 }
230
231 pub fn get(&self, lang_name: &str) -> Option<&dyn LanguageExtractor> {
232 self.extractors.get(lang_name).map(|e| e.as_ref())
233 }
234
235 pub fn extractor_for_extension(&self, ext: &str) -> Option<&dyn LanguageExtractor> {
236 let lang = Lang::from_extension(ext)?;
237 self.get(lang.name())
238 }
239
240 pub fn complexity_for_source(&self, lang: Lang, source: &[u8]) -> Result<ComplexityMetrics> {
241 let extractor = self.get(lang.name()).ok_or_else(|| {
242 anyhow::anyhow!("no extractor registered for language: {}", lang.name())
243 })?;
244 extractor.extract_complexity(source)
245 }
246
247 pub fn registered_languages(&self) -> Vec<&str> {
248 let mut names: Vec<&str> = self.extractors.keys().map(|s| s.as_str()).collect();
249 names.sort();
250 names
251 }
252}
253
254impl Default for LanguageRegistry {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Every language listed by `Lang::all()` must appear in a fresh registry.
    #[test]
    fn registry_has_all_builtin_languages() {
        let registry = LanguageRegistry::new();
        let registered = registry.registered_languages();
        for lang in Lang::all() {
            let name = lang.name();
            assert!(
                registered.contains(&name),
                "missing builtin language: {}",
                name
            );
        }
    }

    /// Two `if`s, one `for` and one `return` in a small Rust function.
    #[cfg(feature = "lang-rust")]
    #[test]
    fn rust_complexity_counting() {
        let sample = br#"fn example(x: i32) -> i32 {
    if x > 0 {
        return x;
    }
    for i in 0..x {
        if i % 2 == 0 {
            continue;
        }
    }
    0
}
"#;
        let ComplexityMetrics {
            branches,
            loops,
            returns,
            ..
        } = LanguageRegistry::new()
            .complexity_for_source(Lang::Rust, sample)
            .unwrap();
        assert!(branches >= 2, "expected >=2 branches, got {}", branches);
        assert!(loops >= 1, "expected >=1 loop, got {}", loops);
        assert!(returns >= 1, "expected >=1 return, got {}", returns);
    }

    /// Two `if`s, one `for` and two `return`s in a small Python function.
    #[cfg(feature = "lang-python")]
    #[test]
    fn python_complexity_counting() {
        let sample = br#"def example(x):
    if x > 0:
        return x
    for i in range(x):
        if i % 2 == 0:
            continue
    return 0
"#;
        let ComplexityMetrics {
            branches,
            loops,
            returns,
            ..
        } = LanguageRegistry::new()
            .complexity_for_source(Lang::Python, sample)
            .unwrap();
        assert!(branches >= 2, "expected >=2 branches, got {}", branches);
        assert!(loops >= 1, "expected >=1 loop, got {}", loops);
        assert!(returns >= 2, "expected >=2 returns, got {}", returns);
    }

    /// Two `if`s, one `for` and two `return`s in a small TypeScript function.
    #[cfg(feature = "lang-typescript")]
    #[test]
    fn typescript_complexity_counting() {
        let sample = br#"function example(x: number): number {
    if (x > 0) {
        return x;
    }
    for (let i = 0; i < x; i++) {
        if (i % 2 === 0) continue;
    }
    return 0;
}
"#;
        let ComplexityMetrics {
            branches,
            loops,
            returns,
            ..
        } = LanguageRegistry::new()
            .complexity_for_source(Lang::TypeScript, sample)
            .unwrap();
        assert!(branches >= 2, "expected >=2 branches, got {}", branches);
        assert!(loops >= 1, "expected >=1 loop, got {}", loops);
        assert!(returns >= 2, "expected >=2 returns, got {}", returns);
    }

    /// The total covers branches, loops and returns only.
    #[test]
    fn total_complexity_sums_metrics() {
        let total = ComplexityMetrics::from_raw_fields(3, 2, 1, 4, 0).total_complexity();
        assert_eq!(total, 6);
    }

    /// Known extensions resolve to an extractor; unknown ones do not.
    #[test]
    fn extractor_for_extension_works() {
        let registry = LanguageRegistry::new();
        for known in ["rs", "py"] {
            assert!(registry.extractor_for_extension(known).is_some());
        }
        assert!(registry.extractor_for_extension("xyz").is_none());
    }
}