Skip to main content

sqry_lang_shell/
lib.rs

1//! Shell language plugin for sqry.
2//!
3//! Provides AST parsing, scope extraction, and graph building for POSIX shell
4//! and bash scripts.
5
6pub mod relations;
7
8pub use relations::ShellGraphBuilder;
9
10use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
11use sqry_core::graph::unified::edge::EdgeKind;
12use sqry_core::metadata::keys as metadata_keys;
13use sqry_core::plugin::{
14    LanguageMetadata, LanguagePlugin, PluginResult,
15    error::{ParseError, ScopeError},
16};
17use sqry_core::query::results::QueryMatch;
18use sqry_core::query::types::{FieldDescriptor, FieldType, Operator, Value};
19use std::fs;
20use std::path::Path;
21use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator, Tree};
22
/// Stable identifier for this plugin, reported as `LanguageMetadata::id`.
const LANGUAGE_ID: &str = "shell";
/// Human-readable language name, reported as `LanguageMetadata::name`.
const LANGUAGE_NAME: &str = "Shell";
/// tree-sitter version string, reported as `LanguageMetadata::tree_sitter_version`.
// NOTE(review): hand-maintained; presumably meant to track the tree-sitter dependency
// in Cargo.toml — confirm it is updated alongside dependency bumps.
const TREE_SITTER_VERSION: &str = "0.23";
26
/// Shell language plugin implementation.
///
/// Parses scripts with the tree-sitter-bash grammar, extracts function scopes, and
/// reports a shell variant (`sh`, `bash`, or `zsh`) detected from the script's shebang.
pub struct ShellPlugin {
    /// Graph builder handed out via `LanguagePlugin::graph_builder`.
    graph_builder: ShellGraphBuilder,
}
31
32impl ShellPlugin {
33    /// Creates a new Shell plugin instance.
34    #[must_use]
35    pub fn new() -> Self {
36        Self {
37            graph_builder: ShellGraphBuilder::default(),
38        }
39    }
40
41    fn detect_shell_variant(content: &[u8]) -> &'static str {
42        if let Some(first_line) = content.split(|&b| b == b'\n').next()
43            && first_line.starts_with(b"#!")
44        {
45            let lowered = String::from_utf8_lossy(first_line).to_lowercase();
46            if lowered.contains("bash") {
47                return "bash";
48            }
49            if lowered.contains("zsh") {
50                return "zsh";
51            }
52            if lowered.contains("sh") {
53                return "sh";
54            }
55        }
56        "sh"
57    }
58
59    fn is_exported(entry: &QueryMatch<'_>) -> bool {
60        let graph = entry.graph();
61        let edges = graph.edges().edges_to(entry.id);
62        edges
63            .iter()
64            .any(|edge| matches!(edge.kind, EdgeKind::Exports { .. }))
65    }
66
67    fn detect_shell_variant_for_entry(entry: &QueryMatch<'_>) -> Option<&'static str> {
68        let path = entry.file_path()?;
69        let content = fs::read(path).ok()?;
70        Some(Self::detect_shell_variant(&content))
71    }
72}
73
74impl Default for ShellPlugin {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl LanguagePlugin for ShellPlugin {
81    fn metadata(&self) -> LanguageMetadata {
82        LanguageMetadata {
83            id: LANGUAGE_ID,
84            name: LANGUAGE_NAME,
85            version: env!("CARGO_PKG_VERSION"),
86            author: "Verivus Pty Ltd",
87            description: "Shell script language support for sqry",
88            tree_sitter_version: TREE_SITTER_VERSION,
89        }
90    }
91
92    fn extensions(&self) -> &'static [&'static str] {
93        &["sh", "bash", "bashrc", "bash_profile", "profile", "env"]
94    }
95
96    fn language(&self) -> Language {
97        tree_sitter_bash::LANGUAGE.into()
98    }
99
100    fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
101        let mut parser = Parser::new();
102        parser
103            .set_language(&self.language())
104            .map_err(|err| ParseError::LanguageSetFailed(err.to_string()))?;
105
106        parser
107            .parse(content, None)
108            .ok_or(ParseError::TreeSitterFailed)
109    }
110
111    fn extract_scopes(
112        &self,
113        tree: &Tree,
114        content: &[u8],
115        file_path: &Path,
116    ) -> Result<Vec<Scope>, ScopeError> {
117        extract_shell_scopes(tree, content, file_path)
118    }
119
120    fn fields(&self) -> &'static [FieldDescriptor] {
121        &[
122            FieldDescriptor {
123                name: metadata_keys::SHELL_VARIANT,
124                field_type: FieldType::String,
125                operators: &[Operator::Equal],
126                indexed: false,
127                doc: "Shell variant (sh, bash, zsh) detected from shebang",
128            },
129            FieldDescriptor {
130                name: metadata_keys::IS_EXPORTED,
131                field_type: FieldType::Bool,
132                operators: &[Operator::Equal],
133                indexed: false,
134                doc: "Whether the symbol is exported (export keyword)",
135            },
136        ]
137    }
138
139    fn evaluate_field(
140        &self,
141        entry: &QueryMatch<'_>,
142        field: &str,
143        value: &Value,
144    ) -> PluginResult<bool> {
145        match field {
146            metadata_keys::SHELL_VARIANT => {
147                let actual = Self::detect_shell_variant_for_entry(entry).unwrap_or("sh");
148                match value {
149                    Value::String(expected) => Ok(actual == expected),
150                    _ => Ok(false),
151                }
152            }
153            metadata_keys::IS_EXPORTED => {
154                let is_exported = Self::is_exported(entry);
155                match value {
156                    Value::Boolean(expected) => Ok(is_exported == *expected),
157                    _ => Ok(false),
158                }
159            }
160            _ => Ok(false),
161        }
162    }
163
164    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
165        Some(&self.graph_builder)
166    }
167}
168
169fn extract_shell_scopes(
170    tree: &Tree,
171    content: &[u8],
172    file_path: &Path,
173) -> Result<Vec<Scope>, ScopeError> {
174    let root_node = tree.root_node();
175    let language = tree_sitter_bash::LANGUAGE.into();
176
177    let scope_query = r"
178; Function definitions (both POSIX and Bash style)
179(function_definition
180  name: (word) @function.name
181) @function.type
182";
183
184    let query = Query::new(&language, scope_query)
185        .map_err(|e| ScopeError::QueryCompilationFailed(e.to_string()))?;
186
187    let mut scopes = Vec::new();
188    let mut cursor = QueryCursor::new();
189    let mut query_matches = cursor.matches(&query, root_node, content);
190
191    while let Some(m) = query_matches.next() {
192        let mut scope_type = None;
193        let mut scope_name = None;
194        let mut scope_start = None;
195        let mut scope_end = None;
196
197        for capture in m.captures {
198            let capture_name = query.capture_names()[capture.index as usize];
199            let node = capture.node;
200
201            let capture_ext = std::path::Path::new(capture_name)
202                .extension()
203                .and_then(|ext| ext.to_str());
204
205            if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("type")) {
206                scope_type = Some("function".to_string());
207                scope_start = Some(node.start_position());
208                scope_end = Some(node.end_position());
209            } else if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("name")) {
210                scope_name = node
211                    .utf8_text(content)
212                    .ok()
213                    .map(std::string::ToString::to_string);
214            }
215        }
216
217        if let (Some(stype), Some(sname), Some(start), Some(end)) =
218            (scope_type, scope_name, scope_start, scope_end)
219        {
220            let scope = Scope {
221                id: ScopeId::new(0),
222                scope_type: stype,
223                name: sname,
224                file_path: file_path.to_path_buf(),
225                start_line: start.row + 1,
226                start_column: start.column,
227                end_line: end.row + 1,
228                end_column: end.column,
229                parent_id: None,
230            };
231            scopes.push(scope);
232        }
233    }
234
235    scopes.sort_by_key(|s| (s.start_line, s.start_column));
236    link_nested_scopes(&mut scopes);
237    Ok(scopes)
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use sqry_core::graph::unified::build::staging::StagingGraph;
    use sqry_core::graph::unified::build::test_helpers::{
        assert_has_export_edge, collect_export_edges,
    };
    use std::fs;
    use std::path::PathBuf;

    /// Loads `tests/fixtures/<name>`, returning its bytes and the path it was read from.
    fn read_fixture(name: &str) -> (Vec<u8>, PathBuf) {
        let fixture_path = PathBuf::from("tests/fixtures").join(name);
        let bytes = fs::read(&fixture_path).expect("failed to read fixture");
        (bytes, fixture_path)
    }

    /// Parses the named fixture and runs the shell graph builder over it.
    fn build_fixture_graph(name: &str) -> StagingGraph {
        let (content, path) = read_fixture(name);
        let plugin = ShellPlugin::default();
        let tree = plugin.parse_ast(&content).expect("parse failed");
        let mut staging = StagingGraph::new();
        plugin
            .graph_builder()
            .expect("graph builder")
            .build_graph(&tree, &content, &path, &mut staging)
            .expect("build graph");
        staging
    }

    #[test]
    fn graph_builder_exports_functions_and_variables() {
        let staging = build_fixture_graph("basic.sh");
        for exported in ["foo", "bar", "DATA_PATH"] {
            assert_has_export_edge(&staging, "basic::module", exported);
        }
    }

    #[test]
    fn graph_builder_uses_direct_exports() {
        let staging = build_fixture_graph("basic.sh");
        let exports = collect_export_edges(&staging);
        assert!(!exports.is_empty(), "expected export edges");
    }

    #[test]
    fn extract_scopes_reports_functions() {
        let (content, path) = read_fixture("basic.sh");
        let plugin = ShellPlugin::default();
        let tree = plugin.parse_ast(&content).expect("parse failed");
        let scopes = plugin
            .extract_scopes(&tree, &content, &path)
            .expect("scope extraction failed");

        for expected in ["foo", "bar"] {
            assert!(
                scopes.iter().any(|scope| scope.name == expected),
                "missing function scope `{expected}`"
            );
        }
    }

    #[test]
    fn detect_shell_variant_reads_shebang() {
        assert_eq!(ShellPlugin::detect_shell_variant(b"#!/bin/bash\necho hi\n"), "bash");
        assert_eq!(ShellPlugin::detect_shell_variant(b"#!/bin/bash\r\necho hi\r\n"), "bash");
        assert_eq!(ShellPlugin::detect_shell_variant(b"#!/usr/bin/env zsh\n"), "zsh");
        assert_eq!(ShellPlugin::detect_shell_variant(b"#!/bin/sh\n"), "sh");
        assert_eq!(ShellPlugin::detect_shell_variant(b"echo no shebang\n"), "sh");
        assert_eq!(ShellPlugin::detect_shell_variant(b""), "sh");
    }
}