Skip to main content

sqry_lang_sql/
lib.rs

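//! SQL language plugin for sqry.
//!
//! Implements the `LanguagePlugin` trait for SQL, providing:
//! - AST parsing with tree-sitter (via the `tree_sitter_sequel` grammar)
//! - Scope extraction for `CREATE FUNCTION` and `CREATE TRIGGER` statements
//! - Relation graph building via [`SqlGraphBuilder`]
//!
//! This plugin enables semantic code search over SQL codebases. SQL is the #5
//! priority language for sqry, given its near-universal use for database
//! queries and data management.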
8
9use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
10use sqry_core::plugin::{
11    LanguageMetadata, LanguagePlugin,
12    error::{ParseError, ScopeError},
13};
14use std::path::Path;
15use tree_sitter::{Language, Node, Parser, Tree};
16
17/// SQL relation extraction and graph building
18pub mod relations;
19
20pub use relations::SqlGraphBuilder;
21
22/// SQL language plugin
23///
24/// Provides language support for SQL files (.sql).
25///
26/// # Example
27///
28/// ```
29/// use sqry_lang_sql::SqlPlugin;
30/// use sqry_core::plugin::LanguagePlugin;
31///
32/// let plugin = SqlPlugin::new();
33/// let metadata = plugin.metadata();
34/// assert_eq!(metadata.id, "sql");
35/// assert_eq!(metadata.name, "SQL");
36/// ```
37pub struct SqlPlugin {
38    graph_builder: SqlGraphBuilder,
39}
40
41impl SqlPlugin {
42    /// Creates a new SQL plugin instance.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            graph_builder: SqlGraphBuilder,
47        }
48    }
49}
50
51impl Default for SqlPlugin {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl LanguagePlugin for SqlPlugin {
58    fn metadata(&self) -> LanguageMetadata {
59        LanguageMetadata {
60            id: "sql",
61            name: "SQL",
62            version: env!("CARGO_PKG_VERSION"),
63            author: "Verivus Pty Ltd",
64            description: "SQL language support for sqry - database schema and query search",
65            tree_sitter_version: "0.24",
66        }
67    }
68
69    fn extensions(&self) -> &'static [&'static str] {
70        &["sql"]
71    }
72
73    fn language(&self) -> Language {
74        tree_sitter_sequel::LANGUAGE.into()
75    }
76
77    fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
78        let mut parser = Parser::new();
79        let language = self.language();
80
81        parser.set_language(&language).map_err(|e| {
82            ParseError::LanguageSetFailed(format!("Failed to set SQL language: {e}"))
83        })?;
84
85        parser
86            .parse(content, None)
87            .ok_or(ParseError::TreeSitterFailed)
88    }
89
90    fn extract_scopes(
91        &self,
92        tree: &Tree,
93        content: &[u8],
94        file_path: &Path,
95    ) -> Result<Vec<Scope>, ScopeError> {
96        let mut scopes = Vec::new();
97        Self::collect_scopes(tree.root_node(), content, file_path, &mut scopes);
98
99        // Sort by position and link nested scopes
100        scopes.sort_by_key(|s| (s.start_line, s.start_column));
101        link_nested_scopes(&mut scopes);
102
103        Ok(scopes)
104    }
105
106    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
107        Some(&self.graph_builder)
108    }
109}
110
111impl SqlPlugin {
112    /// Collect scope information from SQL AST nodes
113    ///
114    /// Extracts scopes for:
115    /// - Functions (`create_function`) - including stored procedures
116    /// - Triggers (`create_trigger`)
117    fn collect_scopes(node: Node, content: &[u8], file_path: &Path, scopes: &mut Vec<Scope>) {
118        match node.kind() {
119            "create_function" => {
120                // Extract function name from object_reference
121                if let Some(name) = Self::extract_name_from_object_reference(&node, content) {
122                    let start = node.start_position();
123                    let end = node.end_position();
124
125                    scopes.push(Scope {
126                        id: ScopeId::new(0),
127                        scope_type: "function".to_string(),
128                        name,
129                        file_path: file_path.to_path_buf(),
130                        start_line: start.row + 1,
131                        start_column: start.column,
132                        end_line: end.row + 1,
133                        end_column: end.column,
134                        parent_id: None,
135                    });
136                }
137            }
138            "create_trigger" => {
139                // Extract trigger name from object_reference
140                if let Some(name) = Self::extract_name_from_object_reference(&node, content) {
141                    let start = node.start_position();
142                    let end = node.end_position();
143
144                    scopes.push(Scope {
145                        id: ScopeId::new(0),
146                        scope_type: "trigger".to_string(),
147                        name,
148                        file_path: file_path.to_path_buf(),
149                        start_line: start.row + 1,
150                        start_column: start.column,
151                        end_line: end.row + 1,
152                        end_column: end.column,
153                        parent_id: None,
154                    });
155                }
156            }
157            _ => {}
158        }
159
160        // Recurse into children
161        let mut cursor = node.walk();
162        for child in node.named_children(&mut cursor) {
163            Self::collect_scopes(child, content, file_path, scopes);
164        }
165    }
166
167    /// Extract name from an `object_reference` child node
168    fn extract_name_from_object_reference(node: &Node, content: &[u8]) -> Option<String> {
169        let mut cursor = node.walk();
170        for child in node.named_children(&mut cursor) {
171            if child.kind() == "object_reference" {
172                // Look for the identifier with name field
173                let mut inner_cursor = child.walk();
174                for inner_child in child.named_children(&mut inner_cursor) {
175                    if inner_child.kind() == "identifier"
176                        && let Ok(text) = inner_child.utf8_text(content)
177                    {
178                        return Some(text.to_string());
179                    }
180                }
181                // Also try the name field
182                if let Some(name_node) = child.child_by_field_name("name")
183                    && let Ok(text) = name_node.utf8_text(content)
184                {
185                    return Some(text.to_string());
186                }
187            }
188        }
189        None
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` with a default plugin and returns the extracted scopes.
    fn scopes_for(source: &[u8]) -> Vec<Scope> {
        let plugin = SqlPlugin::default();
        let tree = plugin.parse_ast(source).unwrap();
        let file = PathBuf::from("test.sql");
        plugin.extract_scopes(&tree, source, &file).unwrap()
    }

    /// `(name, scope_type)` pairs, used in assertion failure messages.
    fn describe(scopes: &[Scope]) -> Vec<(&String, &String)> {
        scopes.iter().map(|s| (&s.name, &s.scope_type)).collect()
    }

    /// All scopes whose `scope_type` equals `scope_type`.
    fn of_type<'a>(scopes: &'a [Scope], scope_type: &str) -> Vec<&'a Scope> {
        scopes.iter().filter(|s| s.scope_type == scope_type).collect()
    }

    /// Names of the given scopes, used in assertion failure messages.
    fn names<'a>(scopes: &[&'a Scope]) -> Vec<&'a String> {
        scopes.iter().map(|s| &s.name).collect()
    }

    #[test]
    fn test_metadata() {
        let metadata = SqlPlugin::default().metadata();

        assert_eq!(metadata.id, "sql");
        assert_eq!(metadata.name, "SQL");
        assert_eq!(metadata.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(metadata.author, "Verivus Pty Ltd");
        assert_eq!(metadata.tree_sitter_version, "0.24");
    }

    #[test]
    fn test_extensions() {
        let extensions = SqlPlugin::default().extensions();

        assert_eq!(extensions.len(), 1);
        assert!(extensions.contains(&"sql"));
    }

    #[test]
    fn test_language() {
        // Any loaded grammar reports a non-zero ABI version.
        let language = SqlPlugin::default().language();
        assert!(language.abi_version() > 0);
    }

    #[test]
    fn test_parse_ast_simple() {
        let tree = SqlPlugin::default()
            .parse_ast(b"CREATE TABLE users (id INT);")
            .unwrap();
        assert!(!tree.root_node().has_error());
    }

    #[test]
    fn test_plugin_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SqlPlugin>();
    }

    #[test]
    fn test_extract_function_scope() {
        let scopes = scopes_for(
            b"CREATE FUNCTION calculate_tax(amount DECIMAL)
RETURNS DECIMAL
AS $$ BEGIN RETURN amount * 0.1; END; $$ LANGUAGE plpgsql;",
        );

        let Some(func_scope) = scopes
            .iter()
            .find(|s| s.name == "calculate_tax" && s.scope_type == "function")
        else {
            panic!(
                "calculate_tax function scope should be extracted, got: {:?}",
                describe(&scopes)
            );
        };

        // A function declared at top level is not nested in anything.
        assert_eq!(
            func_scope.parent_id, None,
            "Top-level function scope should have parent_id = None"
        );
    }

    #[test]
    fn test_extract_trigger_scope() {
        let scopes = scopes_for(
            b"CREATE TRIGGER update_timestamp
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE FUNCTION update_modified_column();",
        );

        let Some(trigger_scope) = scopes
            .iter()
            .find(|s| s.name == "update_timestamp" && s.scope_type == "trigger")
        else {
            panic!(
                "update_timestamp trigger scope should be extracted, got: {:?}",
                describe(&scopes)
            );
        };

        // A trigger declared at top level is not nested in anything.
        assert_eq!(
            trigger_scope.parent_id, None,
            "Top-level trigger scope should have parent_id = None"
        );
    }

    #[test]
    fn test_multiple_scopes() {
        // Two functions (with parameters) followed by one trigger.
        let scopes = scopes_for(
            b"CREATE FUNCTION calculate_total(price DECIMAL)
RETURNS DECIMAL AS $$ BEGIN RETURN price * 1.1; END; $$ LANGUAGE plpgsql;

CREATE FUNCTION get_user_count(status VARCHAR)
RETURNS INT AS $$ BEGIN RETURN 0; END; $$ LANGUAGE plpgsql;

CREATE TRIGGER audit_changes
BEFORE UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION log_update();",
        );

        let func_scopes = of_type(&scopes, "function");
        let trigger_scopes = of_type(&scopes, "trigger");

        assert!(
            func_scopes.len() >= 2,
            "Should have at least 2 function scopes, got: {} - names: {:?}",
            func_scopes.len(),
            names(&func_scopes)
        );
        assert!(
            !trigger_scopes.is_empty(),
            "Should have at least 1 trigger scope, got: {} - names: {:?}",
            trigger_scopes.len(),
            names(&trigger_scopes)
        );
    }
}