swiftide_integrations/treesitter/
code_tree.rs1#![allow(dead_code)]
5use itertools::Itertools;
6use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator as _, Tree};
7
8use anyhow::{Context as _, Result};
9use std::collections::HashSet;
10
11use crate::treesitter::queries::{
12 csharp, go, java, javascript, python, ruby, rust, solidity, typescript,
13};
14
15use super::SupportedLanguages;
16
17#[derive(Debug)]
18pub struct CodeParser {
19 language: SupportedLanguages,
20}
21
22impl CodeParser {
23 pub fn from_language(language: SupportedLanguages) -> Self {
24 Self { language }
25 }
26
27 pub fn parse<'a>(&self, code: &'a str) -> Result<CodeTree<'a>> {
38 let mut parser = Parser::new();
39 parser.set_language(&self.language.into())?;
40 let ts_tree = parser.parse(code, None).context("No nodes found")?;
41
42 Ok(CodeTree {
43 ts_tree,
44 code,
45 language: self.language,
46 })
47 }
48}
49
50pub struct CodeTree<'a> {
52 ts_tree: Tree,
53 code: &'a str,
54 language: SupportedLanguages,
55}
56
/// Symbol names extracted from a `CodeTree` by its language's tree-sitter queries.
///
/// As produced by `CodeTree::references_and_definitions`, both lists are sorted and
/// de-duplicated, and no name that appears in `definitions` also appears in `references`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ReferencesAndDefinitions {
    /// Names used by the code that are not captured as definitions in the same code.
    pub references: Vec<String>,
    /// Names defined by the code (for example functions, classes or contracts).
    pub definitions: Vec<String>,
}
61
62impl CodeTree<'_> {
63 pub fn references_and_definitions(&self) -> Result<ReferencesAndDefinitions> {
70 let (defs, refs) = ts_queries_for_language(self.language);
71
72 let defs_query = Query::new(&self.language.into(), defs)?;
73 let refs_query = Query::new(&self.language.into(), refs)?;
74
75 let defs = self.ts_query_for_matches(&defs_query)?;
76 let refs = self.ts_query_for_matches(&refs_query)?;
77
78 Ok(ReferencesAndDefinitions {
79 references: refs
81 .into_iter()
82 .filter(|r| !defs.contains(r))
83 .sorted()
84 .collect(),
85 definitions: defs.into_iter().sorted().collect(),
86 })
87 }
88
89 fn ts_query_for_matches(&self, query: &Query) -> Result<HashSet<String>> {
91 let mut cursor = QueryCursor::new();
92
93 cursor
94 .matches(query, self.ts_tree.root_node(), self.code.as_bytes())
95 .map_deref(|m| {
96 m.captures
97 .iter()
98 .map(|c| {
99 Ok(c.node
100 .utf8_text(self.code.as_bytes())
101 .context("Failed to parse node")?
102 .to_string())
103 })
104 .collect::<Result<Vec<_>>>()
105 .map(|s| s.join(""))
106 })
107 .collect::<Result<HashSet<_>>>()
108 }
109}
110
111fn ts_queries_for_language(language: SupportedLanguages) -> (&'static str, &'static str) {
112 use SupportedLanguages::{
113 C, CSharp, Cpp, Elixir, Go, HTML, Java, Javascript, PHP, Python, Ruby, Rust, Solidity,
114 Typescript,
115 };
116
117 match language {
118 Rust => (rust::DEFS, rust::REFS),
119 Python => (python::DEFS, python::REFS),
120 Typescript => (typescript::DEFS, typescript::REFS),
122 Javascript => (javascript::DEFS, javascript::REFS),
123 Ruby => (ruby::DEFS, ruby::REFS),
124 Java => (java::DEFS, java::REFS),
125 Go => (go::DEFS, go::REFS),
126 CSharp => (csharp::DEFS, csharp::REFS),
127 Solidity => (solidity::DEFS, solidity::REFS),
128 C | Cpp | Elixir | PHP | HTML => unimplemented!(),
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as `language` and runs the reference/definition queries on it,
    /// panicking if either step fails.
    fn analyze(language: SupportedLanguages, code: &str) -> ReferencesAndDefinitions {
        CodeParser::from_language(language)
            .parse(code)
            .unwrap()
            .references_and_definitions()
            .unwrap()
    }

    #[test]
    fn test_parsing_on_rust() {
        let snippet = r#"
        use std::io;

        fn main() {
            println!("Hello, world!");
        }
        "#;

        let found = analyze(SupportedLanguages::Rust, snippet);

        assert_eq!(found.references, ["println"]);
        assert_eq!(found.definitions, ["main"]);
    }

    #[test]
    fn test_parsing_on_solidity() {
        let snippet = r"
        pragma solidity ^0.8.0;

        contract MyContract {
            function myFunction() public {
                emit MyEvent();
            }
        }
        ";

        let found = analyze(SupportedLanguages::Solidity, snippet);

        assert_eq!(found.references, ["MyEvent"]);
        assert_eq!(found.definitions, ["MyContract", "myFunction"]);
    }

    #[test]
    fn test_parsing_on_ruby() {
        let snippet = r#"
        class A < Inheritance
          include ActuallyAlsoInheritance

          def a
            puts "A"
          end
        end
        "#;

        let found = analyze(SupportedLanguages::Ruby, snippet);

        assert_eq!(
            found.references,
            ["ActuallyAlsoInheritance", "Inheritance", "include", "puts"]
        );
        assert_eq!(found.definitions, ["A", "a"]);
    }

    #[test]
    fn test_parsing_python() {
        // Python is indentation-sensitive, so the snippet starts at column 0.
        let snippet = r#"
class A:
    def __init__(self):
        self.a = [x for x in range(10)]

def hello_world():
    print("Hello, world!")
"#;

        let found = analyze(SupportedLanguages::Python, snippet);

        assert_eq!(found.references, ["print", "range"]);
        assert_eq!(found.definitions, ["A", "hello_world"]);
    }

    #[test]
    fn test_parsing_on_c_sharp() {
        let snippet = r#"
        public class Greeter
        {
            public void SayHello()
            {
                System.Console.WriteLine("Hello, world!");
            }
        }
        "#;

        let found = analyze(SupportedLanguages::CSharp, snippet);

        assert_eq!(found.references, ["WriteLine"]);
        assert_eq!(found.definitions, ["Greeter", "SayHello"]);
    }

    #[test]
    fn test_parsing_on_typescript() {
        let snippet = r#"
        function Test() {
            console.log("Hello, TypeScript!");
            otherThing();
        }

        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }

            myMethod() {
                console.log("Hello, TypeScript!");
            }
        }
        "#;

        let found = analyze(SupportedLanguages::Typescript, snippet);

        assert_eq!(found.definitions, ["MyClass", "Test", "myMethod"]);
        assert_eq!(found.references, ["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_javascript() {
        let snippet = r#"
        function Test() {
            console.log("Hello, JavaScript!");
            otherThing();
        }
        class MyClass {
            constructor() {
                let local = 5;
                this.myMethod();
            }
            myMethod() {
                console.log("Hello, JavaScript!");
            }
        }
        "#;

        let found = analyze(SupportedLanguages::Javascript, snippet);

        assert_eq!(found.definitions, ["MyClass", "Test", "myMethod"]);
        assert_eq!(found.references, ["log", "otherThing"]);
    }

    #[test]
    fn test_parsing_on_java() {
        let snippet = r#"
        public class Hello {
            public static void main(String[] args) {
                System.out.printf("Hello %s!%n", args[0]);
            }
        }
        "#;

        let found = analyze(SupportedLanguages::Java, snippet);

        assert_eq!(found.definitions, ["Hello", "main"]);
        assert_eq!(found.references, ["printf"]);
    }

    #[test]
    fn test_parsing_on_java_enum() {
        let snippet = r"
        enum Material {
            DENIM,
            CANVAS,
            SPANDEX_3_PERCENT
        }

        class Person {


            Person(string name) {
                this.name = name;

                this.pants = new Pants<Pocket>();
            }

            String getName() {
                a = this.name;
                b = new one.two.Three();
                c = Material.DENIM;
            }
        }
        ";

        let found = analyze(SupportedLanguages::Java, snippet);

        assert_eq!(found.definitions, ["Material", "Person", "getName"]);
        assert!(found.references.is_empty());
    }

    #[test]
    fn test_parsing_go() {
        let snippet = r"
        package main

        type Person struct {
            name string
            age int
        }

        func main() {
            p := Person{name: 'John', age: 30}
            fmt.Println(p)
        }
        ";

        let found = analyze(SupportedLanguages::Go, snippet);

        assert_eq!(found.references, ["Println", "int", "string"]);
        assert_eq!(found.definitions, ["Person", "main"]);
    }
}