Skip to main content

sqry_lang_go/
lib.rs

1//! Go language plugin for sqry
2//!
3//! Implements the `LanguagePlugin` trait for Go, providing:
4//! - AST parsing with tree-sitter
5//! - Scope extraction
6//! - **Relation tracking** (calls/imports/exports); new semantics must flow through `sqry_core::graph::GraphBuilder` and `GoGraphBuilder` into `CodeGraph`
7
8pub mod relations;
9
10pub use relations::GoGraphBuilder;
11
12use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
13use sqry_core::plugin::{
14    LanguageMetadata, LanguagePlugin,
15    error::{ParseError, ScopeError},
16};
17use std::path::Path;
18use streaming_iterator::StreamingIterator;
19use tree_sitter::{Parser, Query, QueryCursor, Tree};
20
21/// Go language plugin
22pub struct GoPlugin {
23    graph_builder: GoGraphBuilder,
24}
25
26impl GoPlugin {
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            graph_builder: GoGraphBuilder::default(),
31        }
32    }
33}
34
35impl Default for GoPlugin {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl LanguagePlugin for GoPlugin {
42    fn metadata(&self) -> LanguageMetadata {
43        LanguageMetadata {
44            id: "go",
45            name: "Go",
46            version: env!("CARGO_PKG_VERSION"),
47            author: "Verivus Pty Ltd",
48            description: "Go language support for sqry",
49            tree_sitter_version: "0.24",
50        }
51    }
52
53    fn extensions(&self) -> &'static [&'static str] {
54        &["go"]
55    }
56
57    fn language(&self) -> tree_sitter::Language {
58        tree_sitter_go::LANGUAGE.into()
59    }
60
61    fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
62        let mut parser = Parser::new();
63        parser
64            .set_language(&self.language())
65            .map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
66
67        parser
68            .parse(content, None)
69            .ok_or(ParseError::TreeSitterFailed)
70    }
71
72    fn extract_scopes(
73        &self,
74        tree: &Tree,
75        content: &[u8],
76        file_path: &Path,
77    ) -> Result<Vec<Scope>, ScopeError> {
78        Self::extract_go_scopes(tree, content, file_path)
79    }
80
81    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
82        Some(&self.graph_builder)
83    }
84}
85
86impl GoPlugin {
87    fn extract_go_scopes(
88        tree: &Tree,
89        content: &[u8],
90        file_path: &Path,
91    ) -> Result<Vec<Scope>, ScopeError> {
92        let root_node = tree.root_node();
93        let language = tree_sitter_go::LANGUAGE.into();
94
95        let scope_query = Self::scope_query_source();
96        let query = Query::new(&language, scope_query)
97            .map_err(|e| ScopeError::QueryCompilationFailed(e.to_string()))?;
98
99        let mut scopes = Vec::new();
100        let mut cursor = QueryCursor::new();
101        let mut query_matches = cursor.matches(&query, root_node, content);
102
103        while let Some(m) = query_matches.next() {
104            let mut scope_type = None;
105            let mut scope_name = None;
106            let mut scope_start = None;
107            let mut scope_end = None;
108
109            for capture in m.captures {
110                let capture_name = query.capture_names()[capture.index as usize];
111                let node = capture.node;
112
113                let capture_extension = std::path::Path::new(capture_name)
114                    .extension()
115                    .and_then(|ext| ext.to_str());
116                if capture_extension.is_some_and(|ext| ext.eq_ignore_ascii_case("type")) {
117                    scope_type = Some(capture_name.trim_end_matches(".type").to_string());
118                    scope_start = Some(node.start_position());
119                    scope_end = Some(node.end_position());
120                } else if capture_extension.is_some_and(|ext| ext.eq_ignore_ascii_case("name")) {
121                    scope_name = node
122                        .utf8_text(content)
123                        .ok()
124                        .map(std::string::ToString::to_string);
125                }
126            }
127
128            if let (Some(stype), Some(sname), Some(start), Some(end)) =
129                (scope_type, scope_name, scope_start, scope_end)
130            {
131                let scope = Scope {
132                    id: ScopeId::new(0),
133                    scope_type: stype,
134                    name: sname,
135                    file_path: file_path.to_path_buf(),
136                    start_line: start.row + 1,
137                    start_column: start.column,
138                    end_line: end.row + 1,
139                    end_column: end.column,
140                    parent_id: None,
141                };
142                scopes.push(scope);
143            }
144        }
145
146        scopes.sort_by_key(|s| (s.start_line, s.start_column));
147
148        link_nested_scopes(&mut scopes);
149        Ok(scopes)
150    }
151
152    fn scope_query_source() -> &'static str {
153        r"
154; Function scopes
155(function_declaration
156  name: (identifier) @function.name
157) @function.type
158
159; Method scopes
160(method_declaration
161  name: (field_identifier) @method.name
162) @method.type
163"
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` as Go and extracts its scopes, panicking on any failure.
    fn scopes_for(source: &[u8]) -> Vec<Scope> {
        let plugin = GoPlugin::default();
        let tree = plugin.parse_ast(source).expect("Go source should parse");
        plugin
            .extract_scopes(&tree, source, &PathBuf::from("test.go"))
            .expect("scope extraction should succeed")
    }

    #[test]
    fn test_metadata() {
        let plugin = GoPlugin::default();
        let metadata = plugin.metadata();

        assert_eq!(metadata.id, "go");
        assert_eq!(metadata.name, "Go");
    }

    #[test]
    fn test_extensions() {
        let plugin = GoPlugin::default();
        let extensions = plugin.extensions();

        assert_eq!(extensions.len(), 1);
        assert!(extensions.contains(&"go"));
    }

    #[test]
    fn test_parse_ast_simple() {
        let plugin = GoPlugin::default();
        let source = b"package main\nfunc main() {}";

        let tree = plugin.parse_ast(source).unwrap();
        assert!(!tree.root_node().has_error());
    }

    #[test]
    fn test_extract_scopes_simple() {
        let plugin = GoPlugin::default();
        let source = b"package main\nfunc hello() {}\nfunc world() int { return 42 }";
        let file = PathBuf::from("test.go");

        let tree = plugin.parse_ast(source).unwrap();
        let scopes = plugin.extract_scopes(&tree, source, &file).unwrap();

        assert!(scopes.iter().any(|s| s.name == "hello"));
        assert!(scopes.iter().any(|s| s.name == "world"));
    }

    #[test]
    fn test_extract_scopes_methods_and_types() {
        let source = b"package main\ntype S struct{}\nfunc (s *S) Foo() {}\nfunc bar() {}";
        let scopes = scopes_for(source);

        let foo = scopes
            .iter()
            .find(|s| s.name == "Foo")
            .expect("method scope should be extracted");
        assert_eq!(foo.scope_type, "method");
        assert_eq!(foo.start_line, 3);

        let bar = scopes
            .iter()
            .find(|s| s.name == "bar")
            .expect("function scope should be extracted");
        assert_eq!(bar.scope_type, "function");
        assert_eq!(bar.start_line, 4);
    }

    #[test]
    fn test_extract_scopes_empty_file() {
        assert!(scopes_for(b"package main\n").is_empty());
    }

    #[test]
    fn test_extract_scopes_is_repeatable() {
        // The second call reuses the cached compiled query.
        let source = b"package main\nfunc a() {}\nfunc b() {}";
        let first: Vec<String> = scopes_for(source).into_iter().map(|s| s.name).collect();
        let second: Vec<String> = scopes_for(source).into_iter().map(|s| s.name).collect();

        assert_eq!(first, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(first, second);
    }
}