swiftide_integrations/treesitter/
code_tree.rs

1//! Code parsing
2//!
3//! Extracts typed semantics from code.
4#![allow(dead_code)]
5use itertools::Itertools;
6use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator as _, Tree};
7
8use anyhow::{Context as _, Result};
9use std::collections::HashSet;
10
11use crate::treesitter::queries::{go, java, javascript, python, ruby, rust, solidity, typescript};
12
13use super::SupportedLanguages;
14
/// Parses source code written in one [`SupportedLanguages`] variant into a [`CodeTree`].
#[derive(Debug)]
pub struct CodeParser {
    /// Language handed to the tree-sitter parser in [`CodeParser::parse`] and carried over into
    /// the resulting [`CodeTree`].
    language: SupportedLanguages,
}
19
20impl CodeParser {
21    pub fn from_language(language: SupportedLanguages) -> Self {
22        Self { language }
23    }
24
25    /// Parses code and returns a `CodeTree`
26    ///
27    /// Tree-sitter is pretty lenient and will parse invalid code. I.e. if the code is invalid,
28    /// queries might fail and return no results.
29    ///
30    /// This is good as it makes this safe to use for chunked code as well.
31    ///
32    /// # Errors
33    ///
34    /// Errors if the language is not support or if the tree cannot be parsed
35    pub fn parse<'a>(&self, code: &'a str) -> Result<CodeTree<'a>> {
36        let mut parser = Parser::new();
37        parser.set_language(&self.language.into())?;
38        let ts_tree = parser.parse(code, None).context("No nodes found")?;
39
40        Ok(CodeTree {
41            ts_tree,
42            code,
43            language: self.language,
44        })
45    }
46}
47
/// A code tree is a queryable representation of code
///
/// Built by [`CodeParser::parse`]; it borrows the source text it was parsed from.
pub struct CodeTree<'a> {
    /// Syntax tree produced by tree-sitter for `code`.
    ts_tree: Tree,
    /// Source text the tree was parsed from; used to read the text of matched nodes.
    code: &'a str,
    /// Language the code was parsed as; also selects which queries are run against the tree.
    language: SupportedLanguages,
}
54
/// Symbols extracted from a [`CodeTree`] by [`CodeTree::references_and_definitions`].
///
/// When produced by that method, both lists are sorted and contain no duplicates.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReferencesAndDefinitions {
    /// Referenced symbols that are not also defined in the same code.
    pub references: Vec<String>,
    /// Symbols defined in the code.
    pub definitions: Vec<String>,
}
59
60impl CodeTree<'_> {
61    /// Queries for references and definitions in the code. It returns a unique list of non-local
62    /// references, and local definitions.
63    ///
64    /// # Errors
65    ///
66    /// Errors if the query is invalid or fails
67    pub fn references_and_definitions(&self) -> Result<ReferencesAndDefinitions> {
68        let (defs, refs) = ts_queries_for_language(self.language);
69
70        let defs_query = Query::new(&self.language.into(), defs)?;
71        let refs_query = Query::new(&self.language.into(), refs)?;
72
73        let defs = self.ts_query_for_matches(&defs_query)?;
74        let refs = self.ts_query_for_matches(&refs_query)?;
75
76        Ok(ReferencesAndDefinitions {
77            // Remove any self references
78            references: refs
79                .into_iter()
80                .filter(|r| !defs.contains(r))
81                .sorted()
82                .collect(),
83            definitions: defs.into_iter().sorted().collect(),
84        })
85    }
86
87    /// Given a `tree-sitter` query, searches the code and returns a list of matching symbols
88    fn ts_query_for_matches(&self, query: &Query) -> Result<HashSet<String>> {
89        let mut cursor = QueryCursor::new();
90
91        cursor
92            .matches(query, self.ts_tree.root_node(), self.code.as_bytes())
93            .map_deref(|m| {
94                m.captures
95                    .iter()
96                    .map(|c| {
97                        Ok(c.node
98                            .utf8_text(self.code.as_bytes())
99                            .context("Failed to parse node")?
100                            .to_string())
101                    })
102                    .collect::<Result<Vec<_>>>()
103                    .map(|s| s.join(""))
104            })
105            .collect::<Result<HashSet<_>>>()
106    }
107}
108
109fn ts_queries_for_language(language: SupportedLanguages) -> (&'static str, &'static str) {
110    use SupportedLanguages::{
111        C, Cpp, Elixir, Go, HTML, Java, Javascript, PHP, Python, Ruby, Rust, Solidity, Typescript,
112    };
113
114    match language {
115        Rust => (rust::DEFS, rust::REFS),
116        Python => (python::DEFS, python::REFS),
117        // The univocal proof that TS is just a linter
118        Typescript => (typescript::DEFS, typescript::REFS),
119        Javascript => (javascript::DEFS, javascript::REFS),
120        Ruby => (ruby::DEFS, ruby::REFS),
121        Java => (java::DEFS, java::REFS),
122        Go => (go::DEFS, go::REFS),
123        Solidity => (solidity::DEFS, solidity::REFS),
124        C | Cpp | Elixir | PHP | HTML => unimplemented!(),
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `language` and extracts its references and definitions.
    ///
    /// Panics if parsing or querying fails, which fails the calling test.
    fn extract_symbols(language: SupportedLanguages, source: &str) -> ReferencesAndDefinitions {
        CodeParser::from_language(language)
            .parse(source)
            .unwrap()
            .references_and_definitions()
            .unwrap()
    }

    #[test]
    fn test_parsing_on_rust() {
        let source = r#"
        use std::io;

        fn main() {
            println!("Hello, world!");
        }
        "#;

        let symbols = extract_symbols(SupportedLanguages::Rust, source);

        assert_eq!(symbols.references, ["println"]);
        assert_eq!(symbols.definitions, ["main"]);
    }

    #[test]
    fn test_parsing_on_solidity() {
        let source = r"
        pragma solidity ^0.8.0;

        contract MyContract {
            function myFunction() public {
                emit MyEvent();
            }
        }
        ";

        let symbols = extract_symbols(SupportedLanguages::Solidity, source);

        assert_eq!(symbols.references, ["MyEvent"]);
        assert_eq!(symbols.definitions, ["MyContract", "myFunction"]);
    }

    #[test]
    fn test_parsing_on_ruby() {
        let source = r#"
        class A < Inheritance
          include ActuallyAlsoInheritance

          def a
            puts "A"
          end
        end
        "#;

        let symbols = extract_symbols(SupportedLanguages::Ruby, source);

        assert_eq!(
            symbols.references,
            ["ActuallyAlsoInheritance", "Inheritance", "include", "puts"]
        );
        assert_eq!(symbols.definitions, ["A", "a"]);
    }

    #[test]
    fn test_parsing_python() {
        // A class whose constructor uses a list comprehension, plus a free function.
        let source = r#"
        class A:
            def __init__(self):
                self.a = [x for x in range(10)]

        def hello_world():
            print("Hello, world!")
        "#;

        let symbols = extract_symbols(SupportedLanguages::Python, source);

        assert_eq!(symbols.references, ["print", "range"]);
        assert_eq!(symbols.definitions, ["A", "hello_world"]);
    }

    #[test]
    fn test_parsing_on_typescript() {
        let source = r#"
        function Test() {
            console.log("Hello, TypeScript!");
            otherThing();
        }

        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }

            myMethod() {
                console.log("Hello, TypeScript!");
            }
        }
        "#;

        let symbols = extract_symbols(SupportedLanguages::Typescript, source);

        assert_eq!(symbols.definitions, ["MyClass", "Test", "myMethod"]);
        assert_eq!(symbols.references, ["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_javascript() {
        let source = r#"
        function Test() {
            console.log("Hello, JavaScript!");
            otherThing();
        }
        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }
            myMethod() {
                console.log("Hello, JavaScript!");
            }
        }
        "#;

        let symbols = extract_symbols(SupportedLanguages::Javascript, source);

        assert_eq!(symbols.definitions, ["MyClass", "Test", "myMethod"]);
        assert_eq!(symbols.references, ["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_java() {
        let source = r#"
        public class Hello {
            public static void main(String[] args) {
                System.out.printf("Hello %s!%n", args[0]);
            }
        }
        "#;

        let symbols = extract_symbols(SupportedLanguages::Java, source);

        assert_eq!(symbols.definitions, ["Hello", "main"]);
        assert_eq!(symbols.references, ["printf"]);
    }

    #[test]
    fn test_parsing_on_java_enum() {
        let source = r"
        enum Material {
            DENIM,
            CANVAS,
            SPANDEX_3_PERCENT
        }

        class Person {


          Person(string name) {
            this.name = name;

            this.pants = new Pants<Pocket>();
          }

          String getName() {
            a = this.name;
            b = new one.two.Three();
            c = Material.DENIM;
          }
        }
        ";

        let symbols = extract_symbols(SupportedLanguages::Java, source);

        assert_eq!(symbols.definitions, ["Material", "Person", "getName"]);
        assert!(symbols.references.is_empty());
    }

    #[test]
    fn test_parsing_go() {
        // Hello-world style program that also declares a struct.
        let source = r"
        package main

        type Person struct {
            name string
            age int
        }

        func main() {
            p := Person{name: 'John', age: 30}
            fmt.Println(p)
        }
        ";

        let symbols = extract_symbols(SupportedLanguages::Go, source);

        assert_eq!(symbols.references, ["Println", "int", "string"]);
        assert_eq!(symbols.definitions, ["Person", "main"]);
    }
}