// tsift_graph/complexity.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
4
5use crate::lang::Lang;
6
/// Per-source complexity counters produced by a [`LanguageExtractor`].
///
/// All counters start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct ComplexityMetrics {
    /// Number of branching constructs (e.g. `if`, `match`, `switch`, `when`).
    pub branches: i64,
    /// Number of loop constructs (e.g. `for`, `while`, `loop`, `do`).
    pub loops: i64,
    /// Number of explicit `return` expressions/statements.
    pub returns: i64,
    /// Deepest nesting level observed in the syntax tree.
    pub max_nesting: i64,
    /// Number of `unsafe { .. }` blocks (only meaningful for Rust sources).
    pub unsafe_blocks: i64,
}
15
16impl ComplexityMetrics {
17    pub fn total_complexity(&self) -> i64 {
18        self.branches + self.loops + self.returns
19    }
20
21    pub fn from_raw_fields(
22        branches: i64,
23        loops: i64,
24        returns: i64,
25        max_nesting: i64,
26        unsafe_blocks: i64,
27    ) -> Self {
28        Self {
29            branches,
30            loops,
31            returns,
32            max_nesting,
33            unsafe_blocks,
34        }
35    }
36}
37
38pub trait LanguageExtractor: Send + Sync {
39    fn lang(&self) -> Lang;
40    fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics>;
41}
42
43struct BuiltinExtractor {
44    lang: Lang,
45}
46
47impl BuiltinExtractor {
48    fn complexity_query(&self) -> Option<&'static str> {
49        match self.lang {
50            #[cfg(feature = "lang-rust")]
51            Lang::Rust => Some(
52                r#"
53                (if_expression) @branch
54                (match_expression) @branch
55                (for_expression) @loop
56                (while_expression) @loop
57                (loop_expression) @loop
58                (return_expression) @return
59                (unsafe_block) @unsafe
60            "#,
61            ),
62            #[cfg(feature = "lang-python")]
63            Lang::Python => Some(
64                r#"
65                (if_statement) @branch
66                (elif_clause) @branch
67                (for_statement) @loop
68                (while_statement) @loop
69                (return_statement) @return
70            "#,
71            ),
72            #[cfg(feature = "lang-typescript")]
73            Lang::TypeScript | Lang::Tsx => Some(
74                r#"
75                (if_statement) @branch
76                (switch_statement) @branch
77                (ternary_expression) @branch
78                (for_statement) @loop
79                (for_in_statement) @loop
80                (while_statement) @loop
81                (do_statement) @loop
82                (return_statement) @return
83            "#,
84            ),
85            #[cfg(feature = "lang-javascript")]
86            Lang::JavaScript | Lang::Jsx => Some(
87                r#"
88                (if_statement) @branch
89                (switch_statement) @branch
90                (ternary_expression) @branch
91                (for_statement) @loop
92                (for_in_statement) @loop
93                (while_statement) @loop
94                (do_statement) @loop
95                (return_statement) @return
96            "#,
97            ),
98            #[cfg(feature = "lang-kotlin")]
99            Lang::Kotlin => Some(
100                r#"
101                (if_expression) @branch
102                (when_expression) @branch
103                (for_statement) @loop
104                (while_statement) @loop
105                (do_while_statement) @loop
106                (return_expression) @return
107            "#,
108            ),
109            _ => None,
110        }
111    }
112
113    fn compute_max_nesting(&self, source: &[u8]) -> i64 {
114        let ts_lang = self.lang.tree_sitter_language();
115        let mut parser = Parser::new();
116        if parser.set_language(&ts_lang).is_err() {
117            return 0;
118        }
119        let tree = match parser.parse(source, None) {
120            Some(t) => t,
121            None => return 0,
122        };
123        let mut max_depth: i64 = 0;
124        fn walk(node: tree_sitter::Node, depth: i64, max_depth: &mut i64) {
125            let kind = node.kind();
126            let is_scope = matches!(
127                kind,
128                "function_item"
129                    | "function_definition"
130                    | "function_declaration"
131                    | "class_definition"
132                    | "class_declaration"
133                    | "impl_item"
134                    | "if_expression"
135                    | "if_statement"
136                    | "for_expression"
137                    | "for_statement"
138                    | "while_expression"
139                    | "while_statement"
140                    | "loop_expression"
141                    | "match_expression"
142                    | "switch_statement"
143                    | "when_expression"
144                    | "block"
145                    | "expression_list"
146            );
147            let child_depth = if is_scope { depth + 1 } else { depth };
148            if child_depth > *max_depth {
149                *max_depth = child_depth;
150            }
151            let mut cursor = node.walk();
152            for child in node.children(&mut cursor) {
153                walk(child, child_depth, max_depth);
154            }
155        }
156        walk(tree.root_node(), 0, &mut max_depth);
157        max_depth.max(0)
158    }
159}
160
161impl LanguageExtractor for BuiltinExtractor {
162    fn lang(&self) -> Lang {
163        self.lang
164    }
165
166    fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics> {
167        let query_str = match self.complexity_query() {
168            Some(q) => q,
169            None => return Ok(ComplexityMetrics::default()),
170        };
171        let ts_lang = self.lang.tree_sitter_language();
172        let mut parser = Parser::new();
173        parser.set_language(&ts_lang)?;
174        let tree = parser
175            .parse(source, None)
176            .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
177        let query = Query::new(&ts_lang, query_str)?;
178        let mut cursor = QueryCursor::new();
179        let mut metrics = ComplexityMetrics::default();
180
181        let capture_names: Vec<String> = query
182            .capture_names()
183            .iter()
184            .map(|s| s.to_string())
185            .collect();
186
187        let mut matches = cursor.matches(&query, tree.root_node(), source);
188        while let Some(m) = matches.next() {
189            for capture in m.captures {
190                let name = &capture_names[capture.index as usize];
191                match name.as_str() {
192                    "branch" => metrics.branches += 1,
193                    "loop" => metrics.loops += 1,
194                    "return" => metrics.returns += 1,
195                    "unsafe" => metrics.unsafe_blocks += 1,
196                    _ => {}
197                }
198            }
199        }
200
201        metrics.max_nesting = self.compute_max_nesting(source);
202        Ok(metrics)
203    }
204}
205
206pub struct LanguageRegistry {
207    extractors: HashMap<String, Box<dyn LanguageExtractor>>,
208}
209
210impl LanguageRegistry {
211    pub fn new() -> Self {
212        let mut registry = Self {
213            extractors: HashMap::new(),
214        };
215        registry.register_builtins();
216        registry
217    }
218
219    fn register_builtins(&mut self) {
220        for lang in Lang::all() {
221            let ext = lang.name().to_string();
222            let extractor = BuiltinExtractor { lang };
223            self.extractors.insert(ext, Box::new(extractor));
224        }
225    }
226
227    pub fn register(&mut self, name: String, extractor: Box<dyn LanguageExtractor>) {
228        self.extractors.insert(name, extractor);
229    }
230
231    pub fn get(&self, lang_name: &str) -> Option<&dyn LanguageExtractor> {
232        self.extractors.get(lang_name).map(|e| e.as_ref())
233    }
234
235    pub fn extractor_for_extension(&self, ext: &str) -> Option<&dyn LanguageExtractor> {
236        let lang = Lang::from_extension(ext)?;
237        self.get(lang.name())
238    }
239
240    pub fn complexity_for_source(&self, lang: Lang, source: &[u8]) -> Result<ComplexityMetrics> {
241        let extractor = self.get(lang.name()).ok_or_else(|| {
242            anyhow::anyhow!("no extractor registered for language: {}", lang.name())
243        })?;
244        extractor.extract_complexity(source)
245    }
246
247    pub fn registered_languages(&self) -> Vec<&str> {
248        let mut names: Vec<&str> = self.extractors.keys().map(|s| s.as_str()).collect();
249        names.sort();
250        names
251    }
252}
253
254impl Default for LanguageRegistry {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts exact branch/loop/return counts, so that double-counting regressions
    /// fail just like under-counting ones.
    #[cfg(any(
        feature = "lang-rust",
        feature = "lang-python",
        feature = "lang-typescript"
    ))]
    fn assert_counts(metrics: &ComplexityMetrics, branches: i64, loops: i64, returns: i64) {
        assert_eq!(metrics.branches, branches, "branch count");
        assert_eq!(metrics.loops, loops, "loop count");
        assert_eq!(metrics.returns, returns, "return count");
    }

    #[test]
    fn registry_has_all_builtin_languages() {
        let registry = LanguageRegistry::new();
        let languages = registry.registered_languages();
        for lang in Lang::all() {
            assert!(
                languages.contains(&lang.name()),
                "missing builtin language: {}",
                lang.name()
            );
        }
    }

    #[cfg(feature = "lang-rust")]
    #[test]
    fn rust_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"fn example(x: i32) -> i32 {
    if x > 0 {
        return x;
    }
    for i in 0..x {
        if i % 2 == 0 {
            continue;
        }
    }
    0
}
"#;
        let metrics = registry.complexity_for_source(Lang::Rust, source).unwrap();
        // Two `if` expressions, one `for`, one explicit `return`.
        assert_counts(&metrics, 2, 1, 1);
        assert_eq!(metrics.unsafe_blocks, 0, "unsafe block count");
    }

    #[cfg(feature = "lang-rust")]
    #[test]
    fn rust_counts_unsafe_blocks() {
        let registry = LanguageRegistry::new();
        let source = br#"fn read(p: *const u8) -> u8 {
    unsafe { *p }
}
"#;
        let metrics = registry.complexity_for_source(Lang::Rust, source).unwrap();
        assert_eq!(metrics.unsafe_blocks, 1, "unsafe block count");
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn python_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"def example(x):
    if x > 0:
        return x
    for i in range(x):
        if i % 2 == 0:
            continue
    return 0
"#;
        let metrics = registry
            .complexity_for_source(Lang::Python, source)
            .unwrap();
        // Two `if` statements, one `for`, two `return`s.
        assert_counts(&metrics, 2, 1, 2);
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn typescript_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"function example(x: number): number {
    if (x > 0) {
        return x;
    }
    for (let i = 0; i < x; i++) {
        if (i % 2 === 0) continue;
    }
    return 0;
}
"#;
        let metrics = registry
            .complexity_for_source(Lang::TypeScript, source)
            .unwrap();
        // Two `if` statements, one `for`, two `return`s.
        assert_counts(&metrics, 2, 1, 2);
    }

    #[test]
    fn total_complexity_sums_metrics() {
        let metrics = ComplexityMetrics::from_raw_fields(3, 2, 1, 4, 0);
        assert_eq!(metrics.total_complexity(), 6);
    }

    #[test]
    fn extractor_for_extension_works() {
        let registry = LanguageRegistry::new();
        // NOTE(review): assumes extension support for a language may depend on its
        // cargo feature, so positive lookups are only asserted when it is enabled.
        if cfg!(feature = "lang-rust") {
            assert!(registry.extractor_for_extension("rs").is_some());
        }
        if cfg!(feature = "lang-python") {
            assert!(registry.extractor_for_extension("py").is_some());
        }
        assert!(registry.extractor_for_extension("xyz").is_none());
    }
}
390}