swiftide_integrations/treesitter/code_tree.rs

1//! Code parsing
2//!
3//! Extracts typed semantics from code.
4#![allow(dead_code)]
5use itertools::Itertools;
6use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator as _, Tree};
7
8use anyhow::{Context as _, Result};
9use std::collections::HashSet;
10
11use crate::treesitter::queries::{
12    csharp, go, java, javascript, python, ruby, rust, solidity, typescript,
13};
14
15use super::SupportedLanguages;
16
17#[derive(Debug)]
18pub struct CodeParser {
19    language: SupportedLanguages,
20}
21
22impl CodeParser {
23    pub fn from_language(language: SupportedLanguages) -> Self {
24        Self { language }
25    }
26
27    /// Parses code and returns a `CodeTree`
28    ///
29    /// Tree-sitter is pretty lenient and will parse invalid code. I.e. if the code is invalid,
30    /// queries might fail and return no results.
31    ///
32    /// This is good as it makes this safe to use for chunked code as well.
33    ///
34    /// # Errors
35    ///
36    /// Errors if the language is not support or if the tree cannot be parsed
37    pub fn parse<'a>(&self, code: &'a str) -> Result<CodeTree<'a>> {
38        let mut parser = Parser::new();
39        parser.set_language(&self.language.into())?;
40        let ts_tree = parser.parse(code, None).context("No nodes found")?;
41
42        Ok(CodeTree {
43            ts_tree,
44            code,
45            language: self.language,
46        })
47    }
48}
49
50/// A code tree is a queryable representation of code
51pub struct CodeTree<'a> {
52    ts_tree: Tree,
53    code: &'a str,
54    language: SupportedLanguages,
55}
56
/// Symbols found in a piece of code by [`CodeTree::references_and_definitions`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ReferencesAndDefinitions {
    /// Sorted, unique symbols that are referenced but not defined in the parsed code.
    pub references: Vec<String>,
    /// Sorted, unique symbols that are defined in the parsed code.
    pub definitions: Vec<String>,
}
61
62impl CodeTree<'_> {
63    /// Queries for references and definitions in the code. It returns a unique list of non-local
64    /// references, and local definitions.
65    ///
66    /// # Errors
67    ///
68    /// Errors if the query is invalid or fails
69    pub fn references_and_definitions(&self) -> Result<ReferencesAndDefinitions> {
70        let (defs, refs) = ts_queries_for_language(self.language);
71
72        let defs_query = Query::new(&self.language.into(), defs)?;
73        let refs_query = Query::new(&self.language.into(), refs)?;
74
75        let defs = self.ts_query_for_matches(&defs_query)?;
76        let refs = self.ts_query_for_matches(&refs_query)?;
77
78        Ok(ReferencesAndDefinitions {
79            // Remove any self references
80            references: refs
81                .into_iter()
82                .filter(|r| !defs.contains(r))
83                .sorted()
84                .collect(),
85            definitions: defs.into_iter().sorted().collect(),
86        })
87    }
88
89    /// Given a `tree-sitter` query, searches the code and returns a list of matching symbols
90    fn ts_query_for_matches(&self, query: &Query) -> Result<HashSet<String>> {
91        let mut cursor = QueryCursor::new();
92
93        cursor
94            .matches(query, self.ts_tree.root_node(), self.code.as_bytes())
95            .map_deref(|m| {
96                m.captures
97                    .iter()
98                    .map(|c| {
99                        Ok(c.node
100                            .utf8_text(self.code.as_bytes())
101                            .context("Failed to parse node")?
102                            .to_string())
103                    })
104                    .collect::<Result<Vec<_>>>()
105                    .map(|s| s.join(""))
106            })
107            .collect::<Result<HashSet<_>>>()
108    }
109}
110
111fn ts_queries_for_language(language: SupportedLanguages) -> (&'static str, &'static str) {
112    use SupportedLanguages::{
113        C, CSharp, Cpp, Elixir, Go, HTML, Java, Javascript, PHP, Python, Ruby, Rust, Solidity,
114        Typescript,
115    };
116
117    match language {
118        Rust => (rust::DEFS, rust::REFS),
119        Python => (python::DEFS, python::REFS),
120        // The univocal proof that TS is just a linter
121        Typescript => (typescript::DEFS, typescript::REFS),
122        Javascript => (javascript::DEFS, javascript::REFS),
123        Ruby => (ruby::DEFS, ruby::REFS),
124        Java => (java::DEFS, java::REFS),
125        Go => (go::DEFS, go::REFS),
126        CSharp => (csharp::DEFS, csharp::REFS),
127        Solidity => (solidity::DEFS, solidity::REFS),
128        C | Cpp | Elixir | PHP | HTML => unimplemented!(),
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as `language` and extracts its references and definitions.
    ///
    /// Panics on any failure so each test can focus on the expected symbols.
    fn analyze(language: SupportedLanguages, code: &str) -> ReferencesAndDefinitions {
        let tree = CodeParser::from_language(language).parse(code).unwrap();
        tree.references_and_definitions().unwrap()
    }

    #[test]
    fn test_parsing_on_rust() {
        let code = r#"
        use std::io;

        fn main() {
            println!("Hello, world!");
        }
        "#;
        let found = analyze(SupportedLanguages::Rust, code);

        assert_eq!(found.references, vec!["println"]);
        assert_eq!(found.definitions, vec!["main"]);
    }

    #[test]
    fn test_parsing_on_solidity() {
        let code = r"
        pragma solidity ^0.8.0;

        contract MyContract {
            function myFunction() public {
                emit MyEvent();
            }
        }
        ";
        let found = analyze(SupportedLanguages::Solidity, code);

        assert_eq!(found.references, vec!["MyEvent"]);
        assert_eq!(found.definitions, vec!["MyContract", "myFunction"]);
    }

    #[test]
    fn test_parsing_on_ruby() {
        let code = r#"
        class A < Inheritance
          include ActuallyAlsoInheritance

          def a
            puts "A"
          end
        end
        "#;
        let found = analyze(SupportedLanguages::Ruby, code);

        assert_eq!(
            found.references,
            vec!["ActuallyAlsoInheritance", "Inheritance", "include", "puts"]
        );
        assert_eq!(found.definitions, vec!["A", "a"]);
    }

    #[test]
    fn test_parsing_python() {
        // A class plus a list comprehension; the comprehension variable must not leak out.
        let code = r#"
        class A:
            def __init__(self):
                self.a = [x for x in range(10)]

        def hello_world():
            print("Hello, world!")
        "#;
        let found = analyze(SupportedLanguages::Python, code);

        assert_eq!(found.references, vec!["print", "range"]);
        assert_eq!(found.definitions, vec!["A", "hello_world"]);
    }

    #[test]
    fn test_parsing_on_c_sharp() {
        let code = r#"
        public class Greeter
        {
            public void SayHello()
            {
                System.Console.WriteLine("Hello, world!");
            }
        }
        "#;
        let found = analyze(SupportedLanguages::CSharp, code);

        assert_eq!(found.references, vec!["WriteLine"]);
        assert_eq!(found.definitions, vec!["Greeter", "SayHello"]);
    }

    #[test]
    fn test_parsing_on_typescript() {
        let code = r#"
        function Test() {
            console.log("Hello, TypeScript!");
            otherThing();
        }

        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }

            myMethod() {
                console.log("Hello, TypeScript!");
            }
        }
        "#;
        let found = analyze(SupportedLanguages::Typescript, code);

        assert_eq!(found.definitions, vec!["MyClass", "Test", "myMethod"]);
        assert_eq!(found.references, vec!["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_javascript() {
        let code = r#"
        function Test() {
            console.log("Hello, JavaScript!");
            otherThing();
        }
        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }
            myMethod() {
                console.log("Hello, JavaScript!");
            }
        }
        "#;
        let found = analyze(SupportedLanguages::Javascript, code);

        assert_eq!(found.definitions, vec!["MyClass", "Test", "myMethod"]);
        assert_eq!(found.references, vec!["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_java() {
        let code = r#"
        public class Hello {
            public static void main(String[] args) {
                System.out.printf("Hello %s!%n", args[0]);
            }
        }
        "#;
        let found = analyze(SupportedLanguages::Java, code);

        assert_eq!(found.definitions, vec!["Hello", "main"]);
        assert_eq!(found.references, vec!["printf"]);
    }

    #[test]
    fn test_parsing_on_java_enum() {
        let code = r"
        enum Material {
            DENIM,
            CANVAS,
            SPANDEX_3_PERCENT
        }

        class Person {


          Person(string name) {
            this.name = name;

            this.pants = new Pants<Pocket>();
          }

          String getName() {
            a = this.name;
            b = new one.two.Three();
            c = Material.DENIM;
          }
        }
        ";
        let found = analyze(SupportedLanguages::Java, code);

        assert_eq!(found.definitions, vec!["Material", "Person", "getName"]);
        assert!(found.references.is_empty());
    }

    #[test]
    fn test_parsing_go() {
        // Go "hello world" that also declares and instantiates a struct.
        let code = r"
        package main

        type Person struct {
            name string
            age int
        }

        func main() {
            p := Person{name: 'John', age: 30}
            fmt.Println(p)
        }
        ";
        let found = analyze(SupportedLanguages::Go, code);

        assert_eq!(found.references, vec!["Println", "int", "string"]);
        assert_eq!(found.definitions, vec!["Person", "main"]);
    }
}