1pub mod relations;
9
10pub use relations::GoGraphBuilder;
11
12use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
13use sqry_core::plugin::{
14 LanguageMetadata, LanguagePlugin,
15 error::{ParseError, ScopeError},
16};
17use std::path::Path;
18use streaming_iterator::StreamingIterator;
19use tree_sitter::{Parser, Query, QueryCursor, Tree};
20
21pub struct GoPlugin {
23 graph_builder: GoGraphBuilder,
24}
25
26impl GoPlugin {
27 #[must_use]
28 pub fn new() -> Self {
29 Self {
30 graph_builder: GoGraphBuilder::default(),
31 }
32 }
33}
34
35impl Default for GoPlugin {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl LanguagePlugin for GoPlugin {
42 fn metadata(&self) -> LanguageMetadata {
43 LanguageMetadata {
44 id: "go",
45 name: "Go",
46 version: env!("CARGO_PKG_VERSION"),
47 author: "Verivus Pty Ltd",
48 description: "Go language support for sqry",
49 tree_sitter_version: "0.24",
50 }
51 }
52
53 fn extensions(&self) -> &'static [&'static str] {
54 &["go"]
55 }
56
57 fn language(&self) -> tree_sitter::Language {
58 tree_sitter_go::LANGUAGE.into()
59 }
60
61 fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
62 let mut parser = Parser::new();
63 parser
64 .set_language(&self.language())
65 .map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
66
67 parser
68 .parse(content, None)
69 .ok_or(ParseError::TreeSitterFailed)
70 }
71
72 fn extract_scopes(
73 &self,
74 tree: &Tree,
75 content: &[u8],
76 file_path: &Path,
77 ) -> Result<Vec<Scope>, ScopeError> {
78 Self::extract_go_scopes(tree, content, file_path)
79 }
80
81 fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
82 Some(&self.graph_builder)
83 }
84}
85
86impl GoPlugin {
87 fn extract_go_scopes(
88 tree: &Tree,
89 content: &[u8],
90 file_path: &Path,
91 ) -> Result<Vec<Scope>, ScopeError> {
92 let root_node = tree.root_node();
93 let language = tree_sitter_go::LANGUAGE.into();
94
95 let scope_query = Self::scope_query_source();
96 let query = Query::new(&language, scope_query)
97 .map_err(|e| ScopeError::QueryCompilationFailed(e.to_string()))?;
98
99 let mut scopes = Vec::new();
100 let mut cursor = QueryCursor::new();
101 let mut query_matches = cursor.matches(&query, root_node, content);
102
103 while let Some(m) = query_matches.next() {
104 let mut scope_type = None;
105 let mut scope_name = None;
106 let mut scope_start = None;
107 let mut scope_end = None;
108
109 for capture in m.captures {
110 let capture_name = query.capture_names()[capture.index as usize];
111 let node = capture.node;
112
113 let capture_extension = std::path::Path::new(capture_name)
114 .extension()
115 .and_then(|ext| ext.to_str());
116 if capture_extension.is_some_and(|ext| ext.eq_ignore_ascii_case("type")) {
117 scope_type = Some(capture_name.trim_end_matches(".type").to_string());
118 scope_start = Some(node.start_position());
119 scope_end = Some(node.end_position());
120 } else if capture_extension.is_some_and(|ext| ext.eq_ignore_ascii_case("name")) {
121 scope_name = node
122 .utf8_text(content)
123 .ok()
124 .map(std::string::ToString::to_string);
125 }
126 }
127
128 if let (Some(stype), Some(sname), Some(start), Some(end)) =
129 (scope_type, scope_name, scope_start, scope_end)
130 {
131 let scope = Scope {
132 id: ScopeId::new(0),
133 scope_type: stype,
134 name: sname,
135 file_path: file_path.to_path_buf(),
136 start_line: start.row + 1,
137 start_column: start.column,
138 end_line: end.row + 1,
139 end_column: end.column,
140 parent_id: None,
141 };
142 scopes.push(scope);
143 }
144 }
145
146 scopes.sort_by_key(|s| (s.start_line, s.start_column));
147
148 link_nested_scopes(&mut scopes);
149 Ok(scopes)
150 }
151
152 fn scope_query_source() -> &'static str {
153 r"
154; Function scopes
155(function_declaration
156 name: (identifier) @function.name
157) @function.type
158
159; Method scopes
160(method_declaration
161 name: (field_identifier) @method.name
162) @method.type
163"
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` with a default [`GoPlugin`] and extracts its scopes,
    /// attributing them to `test.go`.
    fn scopes_for(source: &[u8]) -> Vec<Scope> {
        let plugin = GoPlugin::default();
        let file = PathBuf::from("test.go");
        let tree = plugin.parse_ast(source).unwrap();
        plugin.extract_scopes(&tree, source, &file).unwrap()
    }

    #[test]
    fn test_metadata() {
        let plugin = GoPlugin::default();
        let metadata = plugin.metadata();

        assert_eq!(metadata.id, "go");
        assert_eq!(metadata.name, "Go");
    }

    #[test]
    fn test_extensions() {
        let plugin = GoPlugin::default();
        let extensions = plugin.extensions();

        assert_eq!(extensions.len(), 1);
        assert!(extensions.contains(&"go"));
    }

    #[test]
    fn test_parse_ast_simple() {
        let plugin = GoPlugin::default();
        let source = b"package main\nfunc main() {}";

        let tree = plugin.parse_ast(source).unwrap();
        assert!(!tree.root_node().has_error());
    }

    #[test]
    fn test_extract_scopes_simple() {
        let plugin = GoPlugin::default();
        let source = b"package main\nfunc hello() {}\nfunc world() int { return 42 }";
        let file = PathBuf::from("test.go");

        let tree = plugin.parse_ast(source).unwrap();
        let scopes = plugin.extract_scopes(&tree, source, &file).unwrap();

        assert!(scopes.iter().any(|s| s.name == "hello"));
        assert!(scopes.iter().any(|s| s.name == "world"));
    }

    #[test]
    fn test_extract_scopes_function_type_and_position() {
        let scopes = scopes_for(b"package main\nfunc hello() {}\n");

        let hello = scopes
            .iter()
            .find(|s| s.name == "hello")
            .expect("function scope for hello");
        assert_eq!(hello.scope_type, "function");
        // Lines are 1-based: `func hello` is on the second line.
        assert_eq!(hello.start_line, 2);
        assert_eq!(hello.start_column, 0);
        assert_eq!(hello.end_line, 2);
        assert_eq!(hello.file_path, PathBuf::from("test.go"));
    }

    #[test]
    fn test_extract_scopes_method() {
        let scopes = scopes_for(b"package main\ntype Server struct{}\nfunc (s *Server) Start() {}\n");

        let start = scopes
            .iter()
            .find(|s| s.name == "Start")
            .expect("method scope for Start");
        assert_eq!(start.scope_type, "method");
        assert_eq!(start.start_line, 3);
    }

    #[test]
    fn test_extract_scopes_sorted_by_position() {
        let scopes = scopes_for(b"package main\nfunc a() {}\nfunc b() {}\nfunc c() {}\n");

        let names: Vec<&str> = scopes.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, ["a", "b", "c"]);
    }

    #[test]
    fn test_extract_scopes_no_declarations() {
        let scopes = scopes_for(b"package main\n");
        assert!(scopes.is_empty());
    }
}