1use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
10use sqry_core::plugin::{
11 LanguageMetadata, LanguagePlugin,
12 error::{ParseError, ScopeError},
13};
14use std::path::Path;
15use tree_sitter::{Language, Node, Parser, Tree};
16
17pub mod relations;
19
20pub use relations::SqlGraphBuilder;
21
22pub struct SqlPlugin {
38 graph_builder: SqlGraphBuilder,
39}
40
41impl SqlPlugin {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 graph_builder: SqlGraphBuilder,
47 }
48 }
49}
50
51impl Default for SqlPlugin {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl LanguagePlugin for SqlPlugin {
58 fn metadata(&self) -> LanguageMetadata {
59 LanguageMetadata {
60 id: "sql",
61 name: "SQL",
62 version: env!("CARGO_PKG_VERSION"),
63 author: "Verivus Pty Ltd",
64 description: "SQL language support for sqry - database schema and query search",
65 tree_sitter_version: "0.24",
66 }
67 }
68
69 fn extensions(&self) -> &'static [&'static str] {
70 &["sql"]
71 }
72
73 fn language(&self) -> Language {
74 tree_sitter_sequel::LANGUAGE.into()
75 }
76
77 fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
78 let mut parser = Parser::new();
79 let language = self.language();
80
81 parser.set_language(&language).map_err(|e| {
82 ParseError::LanguageSetFailed(format!("Failed to set SQL language: {e}"))
83 })?;
84
85 parser
86 .parse(content, None)
87 .ok_or(ParseError::TreeSitterFailed)
88 }
89
90 fn extract_scopes(
91 &self,
92 tree: &Tree,
93 content: &[u8],
94 file_path: &Path,
95 ) -> Result<Vec<Scope>, ScopeError> {
96 let mut scopes = Vec::new();
97 Self::collect_scopes(tree.root_node(), content, file_path, &mut scopes);
98
99 scopes.sort_by_key(|s| (s.start_line, s.start_column));
101 link_nested_scopes(&mut scopes);
102
103 Ok(scopes)
104 }
105
106 fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
107 Some(&self.graph_builder)
108 }
109}
110
111impl SqlPlugin {
112 fn collect_scopes(node: Node, content: &[u8], file_path: &Path, scopes: &mut Vec<Scope>) {
118 match node.kind() {
119 "create_function" => {
120 if let Some(name) = Self::extract_name_from_object_reference(&node, content) {
122 let start = node.start_position();
123 let end = node.end_position();
124
125 scopes.push(Scope {
126 id: ScopeId::new(0),
127 scope_type: "function".to_string(),
128 name,
129 file_path: file_path.to_path_buf(),
130 start_line: start.row + 1,
131 start_column: start.column,
132 end_line: end.row + 1,
133 end_column: end.column,
134 parent_id: None,
135 });
136 }
137 }
138 "create_trigger" => {
139 if let Some(name) = Self::extract_name_from_object_reference(&node, content) {
141 let start = node.start_position();
142 let end = node.end_position();
143
144 scopes.push(Scope {
145 id: ScopeId::new(0),
146 scope_type: "trigger".to_string(),
147 name,
148 file_path: file_path.to_path_buf(),
149 start_line: start.row + 1,
150 start_column: start.column,
151 end_line: end.row + 1,
152 end_column: end.column,
153 parent_id: None,
154 });
155 }
156 }
157 _ => {}
158 }
159
160 let mut cursor = node.walk();
162 for child in node.named_children(&mut cursor) {
163 Self::collect_scopes(child, content, file_path, scopes);
164 }
165 }
166
167 fn extract_name_from_object_reference(node: &Node, content: &[u8]) -> Option<String> {
169 let mut cursor = node.walk();
170 for child in node.named_children(&mut cursor) {
171 if child.kind() == "object_reference" {
172 let mut inner_cursor = child.walk();
174 for inner_child in child.named_children(&mut inner_cursor) {
175 if inner_child.kind() == "identifier"
176 && let Ok(text) = inner_child.utf8_text(content)
177 {
178 return Some(text.to_string());
179 }
180 }
181 if let Some(name_node) = child.child_by_field_name("name")
183 && let Ok(text) = name_node.utf8_text(content)
184 {
185 return Some(text.to_string());
186 }
187 }
188 }
189 None
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` with a fresh plugin and returns every extracted scope,
    /// reporting `test.sql` as the file path.
    fn scopes_of(source: &[u8]) -> Vec<Scope> {
        let plugin = SqlPlugin::default();
        let file = PathBuf::from("test.sql");
        let tree = plugin.parse_ast(source).unwrap();
        plugin.extract_scopes(&tree, source, &file).unwrap()
    }

    /// `(name, scope_type)` pairs, used only to make failure messages readable.
    fn describe(scopes: &[Scope]) -> Vec<(&String, &String)> {
        scopes.iter().map(|s| (&s.name, &s.scope_type)).collect()
    }

    /// Scopes whose `scope_type` equals `kind`, in their original order.
    fn of_type<'a>(scopes: &'a [Scope], kind: &str) -> Vec<&'a Scope> {
        scopes.iter().filter(|s| s.scope_type == kind).collect()
    }

    /// Names of the given scopes, used only in failure messages.
    fn names<'a>(scopes: &[&'a Scope]) -> Vec<&'a String> {
        scopes.iter().map(|s| &s.name).collect()
    }

    #[test]
    fn test_metadata() {
        let meta = SqlPlugin::default().metadata();

        assert_eq!(meta.id, "sql");
        assert_eq!(meta.name, "SQL");
        assert_eq!(meta.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(meta.author, "Verivus Pty Ltd");
        assert_eq!(meta.tree_sitter_version, "0.24");
    }

    #[test]
    fn test_extensions() {
        let exts = SqlPlugin::default().extensions();

        assert_eq!(exts.len(), 1);
        assert!(exts.contains(&"sql"));
    }

    #[test]
    fn test_language() {
        let abi = SqlPlugin::default().language().abi_version();
        assert!(abi > 0);
    }

    #[test]
    fn test_parse_ast_simple() {
        let tree = SqlPlugin::default()
            .parse_ast(b"CREATE TABLE users (id INT);")
            .unwrap();
        assert!(!tree.root_node().has_error());
    }

    #[test]
    fn test_plugin_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<SqlPlugin>();
    }

    #[test]
    fn test_extract_function_scope() {
        let scopes = scopes_of(
            b"CREATE FUNCTION calculate_tax(amount DECIMAL)
RETURNS DECIMAL
AS $$ BEGIN RETURN amount * 0.1; END; $$ LANGUAGE plpgsql;",
        );

        let func_scope = scopes
            .iter()
            .find(|s| s.scope_type == "function" && s.name == "calculate_tax");
        assert!(
            func_scope.is_some(),
            "calculate_tax function scope should be extracted, got: {:?}",
            describe(&scopes)
        );

        assert_eq!(
            func_scope.unwrap().parent_id,
            None,
            "Top-level function scope should have parent_id = None"
        );
    }

    #[test]
    fn test_extract_trigger_scope() {
        let scopes = scopes_of(
            b"CREATE TRIGGER update_timestamp
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE FUNCTION update_modified_column();",
        );

        let trigger_scope = scopes
            .iter()
            .find(|s| s.scope_type == "trigger" && s.name == "update_timestamp");
        assert!(
            trigger_scope.is_some(),
            "update_timestamp trigger scope should be extracted, got: {:?}",
            describe(&scopes)
        );

        assert_eq!(
            trigger_scope.unwrap().parent_id,
            None,
            "Top-level trigger scope should have parent_id = None"
        );
    }

    #[test]
    fn test_multiple_scopes() {
        let scopes = scopes_of(
            b"CREATE FUNCTION calculate_total(price DECIMAL)
RETURNS DECIMAL AS $$ BEGIN RETURN price * 1.1; END; $$ LANGUAGE plpgsql;

CREATE FUNCTION get_user_count(status VARCHAR)
RETURNS INT AS $$ BEGIN RETURN 0; END; $$ LANGUAGE plpgsql;

CREATE TRIGGER audit_changes
BEFORE UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION log_update();",
        );

        let func_scopes = of_type(&scopes, "function");
        let trigger_scopes = of_type(&scopes, "trigger");

        assert!(
            func_scopes.len() >= 2,
            "Should have at least 2 function scopes, got: {} - names: {:?}",
            func_scopes.len(),
            names(&func_scopes)
        );
        assert!(
            !trigger_scopes.is_empty(),
            "Should have at least 1 trigger scope, got: {} - names: {:?}",
            trigger_scopes.len(),
            names(&trigger_scopes)
        );
    }
}