1use sqry_core::graph::{
10 GraphBuilder, GraphBuilderError, GraphResult, Language, Position, Span,
11 unified::{GraphBuildHelper, StagingGraph},
12};
13use std::path::Path;
14use streaming_iterator::StreamingIterator;
15use tree_sitter::{Query, QueryCursor, Tree};
16
/// A callable SQL object (a `CREATE FUNCTION` or `CREATE TRIGGER` statement) that has
/// been added to the graph, together with the byte range of its defining syntax node.
///
/// The byte range is what `find_enclosing_callable` compares against to decide which
/// callable a table operation or function call belongs to.
#[derive(Debug, Clone)]
struct SqlCallable {
    /// Graph node created for the callable.
    node_id: sqry_core::graph::unified::NodeId,
    /// `start_byte()` of the defining syntax node.
    start_byte: usize,
    /// `end_byte()` of the defining syntax node.
    end_byte: usize,
}
23
/// A table or view definition that has been added to the graph.
///
/// Only the node id is kept; `emit_exports` uses it to attach an export edge.
#[derive(Debug, Clone)]
struct SqlDatabaseObject {
    /// Graph node created for the table or view.
    node_id: sqry_core::graph::unified::NodeId,
}
28
/// Whether a table operation reads from or writes to the table.
#[derive(Debug, Clone)]
enum SqlTableOpKind {
    /// The table is read (a `SELECT ... FROM` statement).
    Read,
    /// The table is modified; the payload says which kind of write it is.
    Write(sqry_core::graph::unified::TableWriteOp),
}
34
/// A single table access found in the source, recorded before it is attributed to a
/// callable and turned into a table read/write edge.
#[derive(Debug, Clone)]
struct SqlTableOp {
    /// `(start_byte, end_byte)` of the statement node; matched against callable byte
    /// ranges to find the enclosing callable.
    op_span_bytes: (usize, usize),
    /// Read or write (and which write).
    kind: SqlTableOpKind,
    /// Table name with any schema qualifier removed.
    table_name: String,
    /// Schema qualifier, when the captured name contained one.
    schema: Option<String>,
    /// Graph node created for the table.
    table_node_id: sqry_core::graph::unified::NodeId,
    /// Source span of the statement node, attached to the emitted edge.
    span: Span,
}
44
/// Name of the synthetic module node that `emit_exports` creates and exports the
/// file's callables, tables and views from.
const FILE_MODULE_NAME: &str = "<file_module>";
50
/// Graph builder for SQL source files.
///
/// A stateless unit struct; all the work happens in its `GraphBuilder` implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct SqlGraphBuilder;
61
62impl SqlGraphBuilder {
63 #[must_use]
65 pub fn new() -> Self {
66 Self
67 }
68}
69
70impl GraphBuilder for SqlGraphBuilder {
71 fn build_graph(
72 &self,
73 tree: &Tree,
74 content: &[u8],
75 file: &Path,
76 staging: &mut StagingGraph,
77 ) -> GraphResult<()> {
78 let mut helper = GraphBuildHelper::new(staging, file, Language::Sql);
80
81 let language = tree_sitter_sequel::LANGUAGE.into();
83 let queries = SqlQueries::new(&language)?;
84
85 let mut callables = extract_procedures(tree, content, &queries.procedures, &mut helper);
87
88 callables.extend(extract_triggers(
90 tree,
91 content,
92 &queries.triggers,
93 &mut helper,
94 ));
95
96 let table_reads = extract_table_reads(tree, content, &queries.table_reads, &mut helper);
98
99 let table_writes = extract_table_writes(tree, content, &queries.table_writes, &mut helper);
101
102 let function_calls = extract_function_calls(tree, content, &queries.function_calls);
104
105 let table_definitions =
107 extract_table_definitions(tree, content, &queries.table_definitions, &mut helper);
108
109 let view_definitions =
111 extract_view_definitions(tree, content, &queries.view_definitions, &mut helper);
112
113 for op in table_reads.into_iter().chain(table_writes) {
115 let Some(caller) = find_enclosing_callable(&callables, op.op_span_bytes) else {
116 continue;
117 };
118
119 match op.kind {
120 SqlTableOpKind::Read => helper.add_table_read_edge_with_span(
121 caller.node_id,
122 op.table_node_id,
123 &op.table_name,
124 op.schema.as_deref(),
125 vec![op.span],
126 ),
127 SqlTableOpKind::Write(operation) => helper.add_table_write_edge_with_span(
128 caller.node_id,
129 op.table_node_id,
130 &op.table_name,
131 op.schema.as_deref(),
132 operation,
133 vec![op.span],
134 ),
135 }
136 }
137
138 for call in function_calls {
140 if let Some(caller) = find_enclosing_callable(&callables, call.span_bytes) {
142 let callee_id =
144 helper.add_function(&call.callee_name, Some(call.span), false, false);
145 helper.add_call_edge_full_with_span(
146 caller.node_id,
147 callee_id,
148 255,
149 false,
150 vec![call.span],
151 );
152 }
153 }
156
157 extract_trigger_execute_function_calls(
160 tree,
161 content,
162 &queries.trigger_execute_function,
163 &callables,
164 &mut helper,
165 );
166
167 emit_exports(
169 &mut helper,
170 &callables,
171 &table_definitions,
172 &view_definitions,
173 );
174
175 Ok(())
176 }
177
178 fn language(&self) -> Language {
179 Language::Sql
180 }
181}
182
/// The compiled tree-sitter queries used by the SQL graph builder.
struct SqlQueries {
    /// Matches `create_function` statements and their names.
    procedures: Query,
    /// Matches `create_trigger` statements with the trigger name and the table after `ON`.
    triggers: Query,
    /// Matches `create_trigger` statements with the function named after `EXECUTE FUNCTION`.
    trigger_execute_function: Query,
    /// Matches `select` statements and the table named in their `FROM` clause.
    table_reads: Query,
    /// Matches `insert`, `update` and `delete` statements and their target table.
    table_writes: Query,
    /// Matches `invocation` nodes and `ERROR` nodes that may contain calls.
    function_calls: Query,
    /// Matches `create_table` statements and their names.
    table_definitions: Query,
    /// Matches `create_view` and `create_materialized_view` statements and their names.
    view_definitions: Query,
}
194
195impl SqlQueries {
196 #[allow(clippy::too_many_lines)]
198 fn new(language: &tree_sitter::Language) -> GraphResult<Self> {
199 let procedures = Query::new(
201 language,
202 r"
203 (create_function
204 (object_reference
205 name: (identifier) @func.name)) @func
206 ",
207 )
208 .map_err(|e| GraphBuilderError::ParseError {
209 span: Span::default(),
210 reason: format!("Failed to compile procedure query: {e}"),
211 })?;
212
213 let triggers = Query::new(
216 language,
217 r"
218 (create_trigger
219 (object_reference
220 name: (identifier) @trigger.name)
221 (keyword_on)
222 (object_reference
223 name: (identifier) @trigger.table)) @trigger
224 ",
225 )
226 .map_err(|e| GraphBuilderError::ParseError {
227 span: Span::default(),
228 reason: format!("Failed to compile trigger query: {e}"),
229 })?;
230
231 let trigger_execute_function = Query::new(
234 language,
235 r"
236 (create_trigger
237 (object_reference
238 name: (identifier) @trigger.name)
239 (keyword_execute)
240 (keyword_function)
241 (object_reference
242 name: (identifier) @func.name)) @trigger_exec
243 ",
244 )
245 .map_err(|e| GraphBuilderError::ParseError {
246 span: Span::default(),
247 reason: format!("Failed to compile trigger_execute_function query: {e}"),
248 })?;
249
250 let table_reads = Query::new(
253 language,
254 r"
255 (statement
256 (select) @select
257 (from
258 (keyword_from)
259 (relation
260 (object_reference
261 name: (identifier) @table.name))))
262 ",
263 )
264 .map_err(|e| GraphBuilderError::ParseError {
265 span: Span::default(),
266 reason: format!("Failed to compile table_reads query: {e}"),
267 })?;
268
269 let table_writes = Query::new(
274 language,
275 r"
276 [
277 (insert
278 (object_reference
279 name: (identifier) @table.name)) @write
280
281 (update
282 (relation
283 (object_reference
284 name: (identifier) @table.name))) @write
285
286 (statement
287 (delete) @write
288 (from
289 (keyword_from)
290 (object_reference
291 name: (identifier) @table.name)))
292 ]
293 ",
294 )
295 .map_err(|e| GraphBuilderError::ParseError {
296 span: Span::default(),
297 reason: format!("Failed to compile table_writes query: {e}"),
298 })?;
299
300 let function_calls = Query::new(
306 language,
307 r#"
308 [
309 (invocation
310 (object_reference
311 name: (identifier) @call.name)) @call
312
313 (ERROR
314 ":="
315 (_) @call.name
316 "(") @call
317
318 (ERROR) @call.error
319 ]
320 "#,
321 )
322 .map_err(|e| GraphBuilderError::ParseError {
323 span: Span::default(),
324 reason: format!("Failed to compile function_calls query: {e}"),
325 })?;
326
327 let table_definitions = Query::new(
329 language,
330 r"
331 (create_table
332 (object_reference
333 name: (identifier) @table.name)) @table
334 ",
335 )
336 .map_err(|e| GraphBuilderError::ParseError {
337 span: Span::default(),
338 reason: format!("Failed to compile table_definitions query: {e}"),
339 })?;
340
341 let view_definitions = Query::new(
343 language,
344 r"
345 [
346 (create_view
347 (object_reference
348 name: (identifier) @view.name)) @view
349 (create_materialized_view
350 (object_reference
351 name: (identifier) @view.name)) @view
352 ]
353 ",
354 )
355 .map_err(|e| GraphBuilderError::ParseError {
356 span: Span::default(),
357 reason: format!("Failed to compile view_definitions query: {e}"),
358 })?;
359
360 Ok(Self {
361 procedures,
362 triggers,
363 trigger_execute_function,
364 table_reads,
365 table_writes,
366 function_calls,
367 table_definitions,
368 view_definitions,
369 })
370 }
371}
372
373fn extract_procedures(
375 tree: &Tree,
376 content: &[u8],
377 query: &Query,
378 helper: &mut GraphBuildHelper,
379) -> Vec<SqlCallable> {
380 let mut callables = Vec::new();
381 let mut cursor = QueryCursor::new();
382 let capture_names = query.capture_names();
383 let mut matches = cursor.matches(query, tree.root_node(), content);
384
385 while let Some(m) = matches.next() {
386 let mut func_name = None;
387 let mut func_node = None;
388
389 for capture in m.captures {
390 let name = capture_names[capture.index as usize];
391 if name == "func.name"
392 && let Ok(text) = capture.node.utf8_text(content)
393 {
394 func_name = Some(text.to_string());
395 }
396 if name == "func" {
397 func_node = Some(capture.node);
398 }
399 }
400
401 if let (Some(name), Some(node)) = (func_name, func_node) {
402 let span = Span::from_node(&node);
403 let node_id = helper.add_function(&name, Some(span), false, false);
404 callables.push(SqlCallable {
405 node_id,
406 start_byte: node.start_byte(),
407 end_byte: node.end_byte(),
408 });
409 }
410 }
411
412 callables
413}
414
415fn extract_triggers(
417 tree: &Tree,
418 content: &[u8],
419 query: &Query,
420 helper: &mut GraphBuildHelper,
421) -> Vec<SqlCallable> {
422 let mut callables = Vec::new();
423 let mut cursor = QueryCursor::new();
424 let capture_names = query.capture_names();
425 let mut matches = cursor.matches(query, tree.root_node(), content);
426
427 while let Some(m) = matches.next() {
428 let mut trigger_name = None;
429 let mut table_name = None;
430 let mut trigger_node = None;
431
432 for capture in m.captures {
433 let name = capture_names[capture.index as usize];
434 match name {
435 "trigger.name" => {
436 if let Ok(text) = capture.node.utf8_text(content) {
437 trigger_name = Some(text.to_string());
438 }
439 }
440 "trigger.table" => {
441 if let Ok(text) = capture.node.utf8_text(content) {
442 table_name = Some(text.to_string());
443 }
444 }
445 "trigger" => {
446 trigger_node = Some(capture.node);
447 }
448 _ => {}
449 }
450 }
451
452 if let (Some(trigger), Some(table), Some(node)) = (trigger_name, table_name, trigger_node) {
453 let (schema, table_only) = split_schema_table(&table);
454 let span = Span::from_node(&node);
455
456 let trigger_id = helper.add_function(&trigger, Some(span), false, false);
457 callables.push(SqlCallable {
458 node_id: trigger_id,
459 start_byte: node.start_byte(),
460 end_byte: node.end_byte(),
461 });
462
463 let table_id = helper.add_variable(table_only, Some(span));
464 helper.add_triggered_by_edge_with_span(
465 trigger_id,
466 table_id,
467 &trigger,
468 schema,
469 vec![span],
470 );
471 }
472 }
473
474 callables
475}
476
477fn extract_trigger_execute_function_calls(
482 tree: &Tree,
483 content: &[u8],
484 query: &Query,
485 callables: &[SqlCallable],
486 helper: &mut GraphBuildHelper,
487) {
488 let mut cursor = QueryCursor::new();
489 let capture_names = query.capture_names();
490 let mut matches = cursor.matches(query, tree.root_node(), content);
491
492 while let Some(m) = matches.next() {
493 let mut trigger_name = None;
494 let mut func_name = None;
495 let mut trigger_node = None;
496
497 for capture in m.captures {
498 let name = capture_names[capture.index as usize];
499 match name {
500 "trigger.name" => {
501 if let Ok(text) = capture.node.utf8_text(content) {
502 trigger_name = Some(text.to_string());
503 }
504 }
505 "func.name" => {
506 if let Ok(text) = capture.node.utf8_text(content) {
507 func_name = Some(text.to_string());
508 }
509 }
510 "trigger_exec" => {
511 trigger_node = Some(capture.node);
512 }
513 _ => {}
514 }
515 }
516
517 if let (Some(_trigger), Some(func), Some(node)) = (trigger_name, func_name, trigger_node) {
518 let span = Span::from_node(&node);
519
520 if let Some(trigger_callable) = callables.iter().find(|c| {
522 c.start_byte <= node.start_byte() && node.end_byte() <= c.end_byte
524 }) {
525 let callee_id = helper.add_function(&func, Some(span), false, false);
527 helper.add_call_edge_full_with_span(
528 trigger_callable.node_id,
529 callee_id,
530 255,
531 false,
532 vec![span],
533 );
534 }
535 }
536 }
537}
538
539fn extract_table_reads(
541 tree: &Tree,
542 content: &[u8],
543 query: &Query,
544 helper: &mut GraphBuildHelper,
545) -> Vec<SqlTableOp> {
546 let mut ops = Vec::new();
547 let mut cursor = QueryCursor::new();
548 let capture_names = query.capture_names();
549 let mut matches = cursor.matches(query, tree.root_node(), content);
550
551 while let Some(m) = matches.next() {
552 let mut table_name = None;
553 let mut op_node = None;
554
555 for capture in m.captures {
556 let name = capture_names[capture.index as usize];
557 match name {
558 "table.name" => {
559 if let Ok(text) = capture.node.utf8_text(content) {
560 table_name = Some(text.to_string());
561 }
562 }
563 "select" => op_node = Some(capture.node),
564 _ => {}
565 }
566 }
567
568 if let (Some(table_name), Some(node)) = (table_name, op_node) {
569 let (schema, table_only) = split_schema_table(&table_name);
570 let span = Span::from_node(&node);
571 let table_node_id = helper.add_variable(table_only, Some(span));
572 ops.push(SqlTableOp {
573 op_span_bytes: (node.start_byte(), node.end_byte()),
574 kind: SqlTableOpKind::Read,
575 table_name: table_only.to_string(),
576 schema: schema.map(str::to_string),
577 table_node_id,
578 span,
579 });
580 }
581 }
582
583 ops
584}
585
586fn extract_table_writes(
588 tree: &Tree,
589 content: &[u8],
590 query: &Query,
591 helper: &mut GraphBuildHelper,
592) -> Vec<SqlTableOp> {
593 let mut ops = Vec::new();
594 let mut cursor = QueryCursor::new();
595 let capture_names = query.capture_names();
596 let mut matches = cursor.matches(query, tree.root_node(), content);
597
598 while let Some(m) = matches.next() {
599 let mut table_name = None;
600 let mut write_node = None;
601
602 for capture in m.captures {
603 let name = capture_names[capture.index as usize];
604 match name {
605 "table.name" => {
606 if let Ok(text) = capture.node.utf8_text(content) {
607 table_name = Some(text.to_string());
608 }
609 }
610 "write" => write_node = Some(capture.node),
611 _ => {}
612 }
613 }
614
615 let Some(table_name) = table_name else {
616 continue;
617 };
618 let Some(node) = write_node else {
619 continue;
620 };
621
622 let operation = match node.kind() {
623 "insert" => sqry_core::graph::unified::TableWriteOp::Insert,
624 "delete" => sqry_core::graph::unified::TableWriteOp::Delete,
625 _ => sqry_core::graph::unified::TableWriteOp::Update,
626 };
627
628 let (schema, table_only) = split_schema_table(&table_name);
629 let span = Span::from_node(&node);
630 let table_node_id = helper.add_variable(table_only, Some(span));
631 ops.push(SqlTableOp {
632 op_span_bytes: (node.start_byte(), node.end_byte()),
633 kind: SqlTableOpKind::Write(operation),
634 table_name: table_only.to_string(),
635 schema: schema.map(str::to_string),
636 table_node_id,
637 span,
638 });
639 }
640
641 ops
642}
643
/// A function call found in the source, before it is attributed to a callable.
#[derive(Debug)]
struct SqlFunctionCall {
    /// Name of the called function with any qualifier removed.
    callee_name: String,
    /// `(start_byte, end_byte)` of the node the call was found in; matched against
    /// callable byte ranges to find the caller.
    span_bytes: (usize, usize),
    /// Source span of that node, attached to the emitted call edge.
    span: Span,
}
651
652fn extract_function_calls(tree: &Tree, content: &[u8], query: &Query) -> Vec<SqlFunctionCall> {
654 let mut calls = Vec::new();
655 let mut cursor = QueryCursor::new();
656 let capture_names = query.capture_names();
657 let mut matches = cursor.matches(query, tree.root_node(), content);
658
659 while let Some(m) = matches.next() {
660 let mut call_name = None;
661 let mut call_node = None;
662
663 for capture in m.captures {
664 let name = capture_names[capture.index as usize];
665 match name {
666 "call.name" => {
667 if let Ok(text) = capture.node.utf8_text(content) {
668 call_name = Some(normalize_callee_name(text));
669 }
670 }
671 "call" | "call.error" => call_node = Some(capture.node),
672 _ => {}
673 }
674 }
675
676 let Some(node) = call_node else {
677 continue;
678 };
679
680 let span_bytes = (node.start_byte(), node.end_byte());
681 let span = Span::from_node(&node);
682
683 if node.kind() == "ERROR" {
684 if let Ok(text) = node.utf8_text(content) {
685 for name in extract_error_call_names(text) {
686 calls.push(SqlFunctionCall {
687 callee_name: name,
688 span_bytes,
689 span,
690 });
691 }
692 }
693 continue;
694 }
695
696 if let Some(name) = call_name
697 && !name.is_empty()
698 {
699 calls.push(SqlFunctionCall {
700 callee_name: name,
701 span_bytes,
702 span,
703 });
704 }
705 }
706
707 calls
708}
709
/// Reduces a possibly qualified callee name to its last dot-separated segment.
///
/// Surrounding whitespace is removed both from the whole input and from the segment
/// that is returned, so `" schema . func "` becomes `"func"`. A name ending in a dot
/// yields an empty string.
fn normalize_callee_name(name: &str) -> String {
    let trimmed = name.trim();
    let last_segment = match trimmed.rfind('.') {
        Some(dot) => &trimmed[dot + 1..],
        None => trimmed,
    };
    last_segment.trim().to_owned()
}
718
719fn extract_error_call_names(text: &str) -> Vec<String> {
720 let bytes = text.as_bytes();
721 let mut offset = 0;
722 let mut call_names = Vec::new();
723
724 while offset < bytes.len() {
725 if !is_sql_identifier_start(bytes[offset]) {
726 offset += 1;
727 continue;
728 }
729
730 let start = offset;
731 offset += 1;
732 while offset < bytes.len() && is_sql_identifier_continue(bytes[offset]) {
733 offset += 1;
734 }
735
736 let token = &text[start..offset];
737 let mut lookahead = offset;
738 while lookahead < bytes.len() && bytes[lookahead].is_ascii_whitespace() {
739 lookahead += 1;
740 }
741
742 if lookahead < bytes.len() && bytes[lookahead] == b'(' {
743 let normalized = normalize_callee_name(token);
744 if !normalized.is_empty() && !call_names.iter().any(|name| name == &normalized) {
745 call_names.push(normalized);
746 }
747 }
748 }
749
750 call_names
751}
752
/// Returns `true` when `byte` may begin an identifier: an ASCII letter or `_`.
const fn is_sql_identifier_start(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'_')
}
756
/// Returns `true` when `byte` may continue an identifier token: an ASCII letter or
/// digit, `_`, or `.` (so that qualified names are read as a single token).
const fn is_sql_identifier_continue(byte: u8) -> bool {
    matches!(byte, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'.')
}
760
761fn extract_table_definitions(
763 tree: &Tree,
764 content: &[u8],
765 query: &Query,
766 helper: &mut GraphBuildHelper,
767) -> Vec<SqlDatabaseObject> {
768 let mut objects = Vec::new();
769 let mut cursor = QueryCursor::new();
770 let capture_names = query.capture_names();
771 let mut matches = cursor.matches(query, tree.root_node(), content);
772
773 while let Some(m) = matches.next() {
774 let mut table_name = None;
775 let mut table_node = None;
776
777 for capture in m.captures {
778 let name = capture_names[capture.index as usize];
779 match name {
780 "table.name" => {
781 if let Ok(text) = capture.node.utf8_text(content) {
782 table_name = Some(text.to_string());
783 }
784 }
785 "table" => table_node = Some(capture.node),
786 _ => {}
787 }
788 }
789
790 if let (Some(name), Some(node)) = (table_name, table_node) {
791 let (_, table_only) = split_schema_table(&name);
793 let span = Span::from_node(&node);
794 let node_id = helper.add_variable(table_only, Some(span));
795 objects.push(SqlDatabaseObject { node_id });
796 }
797 }
798
799 objects
800}
801
802fn extract_view_definitions(
804 tree: &Tree,
805 content: &[u8],
806 query: &Query,
807 helper: &mut GraphBuildHelper,
808) -> Vec<SqlDatabaseObject> {
809 let mut objects = Vec::new();
810 let mut cursor = QueryCursor::new();
811 let capture_names = query.capture_names();
812 let mut matches = cursor.matches(query, tree.root_node(), content);
813
814 while let Some(m) = matches.next() {
815 let mut view_name = None;
816 let mut view_node = None;
817
818 for capture in m.captures {
819 let name = capture_names[capture.index as usize];
820 match name {
821 "view.name" => {
822 if let Ok(text) = capture.node.utf8_text(content) {
823 view_name = Some(text.to_string());
824 }
825 }
826 "view" => view_node = Some(capture.node),
827 _ => {}
828 }
829 }
830
831 if let (Some(name), Some(node)) = (view_name, view_node) {
832 let (_, view_only) = split_schema_table(&name);
834 let span = Span::from_node(&node);
835 let node_id = helper.add_variable(view_only, Some(span));
836 objects.push(SqlDatabaseObject { node_id });
837 }
838 }
839
840 objects
841}
842
843fn find_enclosing_callable(
844 callables: &[SqlCallable],
845 op_span_bytes: (usize, usize),
846) -> Option<&SqlCallable> {
847 let (start_byte, end_byte) = op_span_bytes;
848 callables
849 .iter()
850 .filter(|c| c.start_byte <= start_byte && end_byte <= c.end_byte)
851 .min_by_key(|c| c.end_byte.saturating_sub(c.start_byte))
852}
853
/// Splits a possibly qualified object name into `(qualifier, object)`.
///
/// The object is always the last dot-separated segment, so `db.public.users` yields
/// `(Some("db.public"), "users")` rather than a table name that still contains a dot.
/// Both parts are trimmed. An empty qualifier (as in `.users`) is reported as `None`,
/// and trailing dots are ignored (`users.` yields `(None, "users")`).
fn split_schema_table(name: &str) -> (Option<&str>, &str) {
    let trimmed = name.trim();

    let Some((qualifier, object)) = trimmed.rsplit_once('.') else {
        return (None, trimmed);
    };
    let (qualifier, object) = (qualifier.trim(), object.trim());

    if object.is_empty() {
        // Trailing dot: drop it and split what is left. The input shrinks on every
        // call, so the recursion terminates.
        return split_schema_table(qualifier);
    }

    if qualifier.is_empty() {
        (None, object)
    } else {
        (Some(qualifier), object)
    }
}
863
/// Extension trait for building a value from a tree-sitter node's position.
trait SpanExt {
    /// Builds `Self` from the start and end positions of `node`.
    fn from_node(node: &tree_sitter::Node) -> Self;
}
868
869impl SpanExt for Span {
870 fn from_node(node: &tree_sitter::Node) -> Self {
871 Span::new(
872 Position::new(node.start_position().row, node.start_position().column),
873 Position::new(node.end_position().row, node.end_position().column),
874 )
875 }
876}
877
878fn emit_exports(
884 helper: &mut GraphBuildHelper,
885 callables: &[SqlCallable],
886 tables: &[SqlDatabaseObject],
887 views: &[SqlDatabaseObject],
888) {
889 if callables.is_empty() && tables.is_empty() && views.is_empty() {
891 return;
892 }
893
894 let module_id = helper.add_module(FILE_MODULE_NAME, None);
896
897 for callable in callables {
899 helper.add_export_edge(module_id, callable.node_id);
900 }
901
902 for table in tables {
904 helper.add_export_edge(module_id, table.node_id);
905 }
906
907 for view in views {
909 helper.add_export_edge(module_id, view.node_id);
910 }
911}
912
913#[cfg(test)]
914mod tests {
915 use super::*;
916 use sqry_core::graph::unified::StagingOp;
917 use sqry_core::graph::unified::TableWriteOp;
918 use sqry_core::graph::unified::edge::EdgeKind;
919 use std::path::PathBuf;
920
921 fn parse_sql(sql: &str) -> Tree {
922 let mut parser = tree_sitter::Parser::new();
923 parser
924 .set_language(&tree_sitter_sequel::LANGUAGE.into())
925 .expect("Failed to set SQL language");
926 parser
927 .parse(sql.as_bytes(), None)
928 .expect("Failed to parse SQL")
929 }
930
931 #[allow(dead_code)]
933 fn get_table_read_edges(staging: &StagingGraph) -> Vec<String> {
934 staging
935 .operations()
936 .iter()
937 .filter_map(|op| {
938 if let StagingOp::AddEdge {
939 kind: EdgeKind::TableRead { table_name, .. },
940 ..
941 } = op
942 {
943 Some(format!("TableRead({:?})", table_name))
946 } else {
947 None
948 }
949 })
950 .collect()
951 }
952
953 #[allow(dead_code)]
955 fn get_table_write_edges(staging: &StagingGraph) -> Vec<(String, TableWriteOp)> {
956 staging
957 .operations()
958 .iter()
959 .filter_map(|op| {
960 if let StagingOp::AddEdge {
961 kind:
962 EdgeKind::TableWrite {
963 table_name,
964 operation,
965 ..
966 },
967 ..
968 } = op
969 {
970 Some((format!("TableWrite({:?})", table_name), *operation))
971 } else {
972 None
973 }
974 })
975 .collect()
976 }
977
978 fn count_table_read_edges(staging: &StagingGraph) -> usize {
980 staging
981 .operations()
982 .iter()
983 .filter(|op| {
984 matches!(
985 op,
986 StagingOp::AddEdge {
987 kind: EdgeKind::TableRead { .. },
988 ..
989 }
990 )
991 })
992 .count()
993 }
994
995 fn count_table_write_edges(staging: &StagingGraph) -> usize {
996 staging
997 .operations()
998 .iter()
999 .filter(|op| {
1000 matches!(
1001 op,
1002 StagingOp::AddEdge {
1003 kind: EdgeKind::TableWrite { .. },
1004 ..
1005 }
1006 )
1007 })
1008 .count()
1009 }
1010
1011 fn count_table_write_edges_by_op(staging: &StagingGraph, expected_op: TableWriteOp) -> usize {
1012 staging
1013 .operations()
1014 .iter()
1015 .filter(|op| {
1016 matches!(
1017 op,
1018 StagingOp::AddEdge { kind: EdgeKind::TableWrite { operation, .. }, .. }
1019 if *operation == expected_op
1020 )
1021 })
1022 .count()
1023 }
1024
1025 fn count_call_edges(staging: &StagingGraph) -> usize {
1026 staging
1027 .operations()
1028 .iter()
1029 .filter(|op| {
1030 matches!(
1031 op,
1032 StagingOp::AddEdge {
1033 kind: EdgeKind::Calls { .. },
1034 ..
1035 }
1036 )
1037 })
1038 .count()
1039 }
1040
1041 fn count_export_edges(staging: &StagingGraph) -> usize {
1043 staging
1044 .operations()
1045 .iter()
1046 .filter(|op| {
1047 matches!(
1048 op,
1049 StagingOp::AddEdge {
1050 kind: EdgeKind::Exports { .. },
1051 ..
1052 }
1053 )
1054 })
1055 .count()
1056 }
1057
1058 #[test]
1059 fn test_sql_graph_builder_new() {
1060 let builder = SqlGraphBuilder::new();
1061 assert_eq!(builder.language(), Language::Sql);
1062 }
1063
1064 #[test]
1065 fn test_select_creates_table_read_edge() {
1066 let sql = r"
1067 CREATE FUNCTION get_users()
1068 RETURNS TABLE (id INT, name TEXT) AS $$
1069 SELECT * FROM users;
1070 $$ LANGUAGE sql;
1071 ";
1072
1073 let tree = parse_sql(sql);
1074 let mut staging = StagingGraph::new();
1075 let builder = SqlGraphBuilder::new();
1076 let file = PathBuf::from("test.sql");
1077
1078 builder
1079 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1080 .expect("Graph building should succeed");
1081
1082 let read_count = count_table_read_edges(&staging);
1083 assert!(
1084 read_count >= 1,
1085 "Expected at least 1 TableRead edge, got {read_count}"
1086 );
1087 }
1088
1089 #[test]
1090 fn test_insert_creates_table_write_edge() {
1091 let sql = r"
1092 CREATE FUNCTION create_user(user_name TEXT)
1093 RETURNS VOID AS $$
1094 INSERT INTO users (name) VALUES (user_name);
1095 $$ LANGUAGE sql;
1096 ";
1097
1098 let tree = parse_sql(sql);
1099 let mut staging = StagingGraph::new();
1100 let builder = SqlGraphBuilder::new();
1101 let file = PathBuf::from("test.sql");
1102
1103 builder
1104 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1105 .expect("Graph building should succeed");
1106
1107 let insert_count = count_table_write_edges_by_op(&staging, TableWriteOp::Insert);
1108 assert!(
1109 insert_count >= 1,
1110 "Expected at least 1 TableWrite(Insert) edge, got {insert_count}"
1111 );
1112 }
1113
1114 #[test]
1115 fn test_update_creates_table_write_edge() {
1116 let sql = r"
1117 CREATE FUNCTION update_user(user_id INT, new_name TEXT)
1118 RETURNS VOID AS $$
1119 UPDATE users SET name = new_name WHERE id = user_id;
1120 $$ LANGUAGE sql;
1121 ";
1122
1123 let tree = parse_sql(sql);
1124 let mut staging = StagingGraph::new();
1125 let builder = SqlGraphBuilder::new();
1126 let file = PathBuf::from("test.sql");
1127
1128 builder
1129 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1130 .expect("Graph building should succeed");
1131
1132 let update_count = count_table_write_edges_by_op(&staging, TableWriteOp::Update);
1133 assert!(
1134 update_count >= 1,
1135 "Expected at least 1 TableWrite(Update) edge, got {update_count}"
1136 );
1137 }
1138
1139 #[test]
1140 fn test_delete_creates_table_write_edge() {
1141 let sql = r"
1142 CREATE FUNCTION delete_user(user_id INT)
1143 RETURNS VOID AS $$
1144 DELETE FROM users WHERE id = user_id;
1145 $$ LANGUAGE sql;
1146 ";
1147
1148 let tree = parse_sql(sql);
1149 let mut staging = StagingGraph::new();
1150 let builder = SqlGraphBuilder::new();
1151 let file = PathBuf::from("test.sql");
1152
1153 builder
1154 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1155 .expect("Graph building should succeed");
1156
1157 let delete_count = count_table_write_edges_by_op(&staging, TableWriteOp::Delete);
1158 assert!(
1159 delete_count >= 1,
1160 "Expected at least 1 TableWrite(Delete) edge, got {delete_count}"
1161 );
1162 }
1163
1164 #[test]
1165 fn test_join_creates_table_read_edge_for_primary_table() {
1166 let sql = r"
1169 CREATE FUNCTION get_user_orders()
1170 RETURNS TABLE (user_name TEXT, order_id INT) AS $$
1171 SELECT u.name, o.id FROM users u JOIN orders o ON u.id = o.user_id;
1172 $$ LANGUAGE sql;
1173 ";
1174
1175 let tree = parse_sql(sql);
1176 let mut staging = StagingGraph::new();
1177 let builder = SqlGraphBuilder::new();
1178 let file = PathBuf::from("test.sql");
1179
1180 builder
1181 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1182 .expect("Graph building should succeed");
1183
1184 let read_count = count_table_read_edges(&staging);
1186 assert!(
1187 read_count >= 1,
1188 "Expected at least 1 TableRead edge for FROM clause, got {read_count}"
1189 );
1190 }
1191
1192 #[test]
1193 fn test_multiple_joins_creates_table_read_edge_for_primary_table() {
1194 let sql = r"
1196 CREATE FUNCTION get_order_details()
1197 RETURNS TABLE (user_name TEXT, product_name TEXT, quantity INT) AS $$
1198 SELECT u.name, p.name, o.quantity
1199 FROM users u
1200 JOIN orders o ON u.id = o.user_id
1201 LEFT JOIN products p ON o.product_id = p.id;
1202 $$ LANGUAGE sql;
1203 ";
1204
1205 let tree = parse_sql(sql);
1206 let mut staging = StagingGraph::new();
1207 let builder = SqlGraphBuilder::new();
1208 let file = PathBuf::from("test.sql");
1209
1210 builder
1211 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1212 .expect("Graph building should succeed");
1213
1214 let read_count = count_table_read_edges(&staging);
1216 assert!(
1217 read_count >= 1,
1218 "Expected at least 1 TableRead edge for FROM clause, got {read_count}"
1219 );
1220 }
1221
1222 #[test]
1223 fn test_mixed_read_write_operations() {
1224 let sql = r"
1226 CREATE FUNCTION transfer_funds(from_id INT, to_id INT, amount DECIMAL)
1227 RETURNS VOID AS $$
1228 BEGIN
1229 SELECT balance FROM accounts WHERE id = from_id;
1230 UPDATE accounts SET balance = balance - amount WHERE id = from_id;
1231 UPDATE accounts SET balance = balance + amount WHERE id = to_id;
1232 INSERT INTO transactions (from_account, to_account, amount) VALUES (from_id, to_id, amount);
1233 END;
1234 $$ LANGUAGE plpgsql;
1235 ";
1236
1237 let tree = parse_sql(sql);
1238 let mut staging = StagingGraph::new();
1239 let builder = SqlGraphBuilder::new();
1240 let file = PathBuf::from("test.sql");
1241
1242 builder
1243 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1244 .expect("Graph building should succeed");
1245
1246 let read_count = count_table_read_edges(&staging);
1247 let write_count = count_table_write_edges(&staging);
1248
1249 assert!(
1251 read_count >= 1,
1252 "Expected at least 1 TableRead edge, got {read_count}"
1253 );
1254 assert!(
1255 write_count >= 1,
1256 "Expected at least 1 TableWrite edge, got {write_count}"
1257 );
1258 }
1259
1260 #[test]
1261 fn test_plpgsql_assignment_function_calls_create_call_edges() {
1262 let sql = r"
1263 CREATE FUNCTION add(a INT, b INT) RETURNS INT AS $$
1264 BEGIN
1265 RETURN a + b;
1266 END;
1267 $$ LANGUAGE plpgsql;
1268
1269 CREATE FUNCTION multiply(a INT, b INT) RETURNS INT AS $$
1270 BEGIN
1271 RETURN a * b;
1272 END;
1273 $$ LANGUAGE plpgsql;
1274
1275 CREATE FUNCTION compute(x INT, y INT, z INT) RETURNS INT AS $$
1276 DECLARE
1277 sum_val INT;
1278 BEGIN
1279 sum_val := add(x, y);
1280 RETURN multiply(sum_val, z);
1281 END;
1282 $$ LANGUAGE plpgsql;
1283 ";
1284
1285 let tree = parse_sql(sql);
1286 let mut staging = StagingGraph::new();
1287 let builder = SqlGraphBuilder::new();
1288 let file = PathBuf::from("nested_calls.sql");
1289
1290 builder
1291 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1292 .expect("Graph building should succeed");
1293
1294 let call_count = count_call_edges(&staging);
1295 assert!(
1296 call_count >= 2,
1297 "Expected at least 2 call edges for add() and multiply(), got {call_count}"
1298 );
1299 }
1300
1301 #[test]
1302 fn test_plpgsql_multiple_assignment_calls_create_call_edges() {
1303 let sql = r"
1304 CREATE FUNCTION helper_one() RETURNS INT AS $$
1305 BEGIN
1306 RETURN 42;
1307 END;
1308 $$ LANGUAGE plpgsql;
1309
1310 CREATE FUNCTION helper_two() RETURNS INT AS $$
1311 BEGIN
1312 RETURN 100;
1313 END;
1314 $$ LANGUAGE plpgsql;
1315
1316 CREATE FUNCTION orchestrator() RETURNS INT AS $$
1317 DECLARE
1318 val1 INT;
1319 val2 INT;
1320 BEGIN
1321 val1 := helper_one();
1322 val2 := helper_two();
1323 RETURN val1 + val2;
1324 END;
1325 $$ LANGUAGE plpgsql;
1326 ";
1327
1328 let tree = parse_sql(sql);
1329 let mut staging = StagingGraph::new();
1330 let builder = SqlGraphBuilder::new();
1331 let file = PathBuf::from("multiple_assignment_calls.sql");
1332
1333 builder
1334 .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
1335 .expect("Graph building should succeed");
1336
1337 let call_count = count_call_edges(&staging);
1338 assert!(
1339 call_count >= 2,
1340 "Expected at least 2 call edges for helper_one() and helper_two(), got {call_count}"
1341 );
1342 }
1343
/// Smoke test: a function body that reads from a schema-qualified table
/// (`public.users`) must be accepted by the graph builder.
///
/// Only success of `build_graph` is checked; no edge counts are asserted.
#[test]
fn test_schema_qualified_table_name() {
    let sql = r"
CREATE FUNCTION get_public_users()
RETURNS TABLE (id INT, name TEXT) AS $$
SELECT * FROM public.users;
$$ LANGUAGE sql;
";

    let tree = parse_sql(sql);
    let mut staging = StagingGraph::new();
    let builder = SqlGraphBuilder::new();
    let file = PathBuf::from("test.sql");

    // `expect` (rather than `assert!(result.is_ok())`) keeps the builder's
    // error value in the panic output, so a failure shows its cause.
    builder
        .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
        .expect("Should handle schema-qualified table names");
}
1362
/// A dotted identifier is split into its schema part and its table part.
#[test]
fn test_split_schema_table_with_schema() {
    let parts = split_schema_table("public.users");

    assert_eq!(parts.0, Some("public"));
    assert_eq!(parts.1, "users");
}
1369
/// An identifier with no dot yields no schema and the name unchanged.
#[test]
fn test_split_schema_table_without_schema() {
    let parts = split_schema_table("users");

    assert_eq!(parts.0, None);
    assert_eq!(parts.1, "users");
}
1376
/// Whitespace surrounding the schema and table parts is not part of the
/// returned components.
#[test]
fn test_split_schema_table_with_whitespace() {
    let parts = split_schema_table(" public . users ");

    assert_eq!(parts.0, Some("public"));
    assert_eq!(parts.1, "users");
}
1383
/// Smoke test: building a graph from an empty SQL source must succeed.
#[test]
fn test_empty_sql_file() {
    let sql = "";
    let tree = parse_sql(sql);
    let mut staging = StagingGraph::new();
    let builder = SqlGraphBuilder::new();
    let file = PathBuf::from("empty.sql");

    // `expect` (rather than `assert!(result.is_ok())`) keeps the builder's
    // error value in the panic output, so a failure shows its cause.
    builder
        .build_graph(&tree, sql.as_bytes(), &file, &mut staging)
        .expect("Should handle empty SQL files");
}
1395
/// A top-level `SELECT` that is not inside any function is expected to
/// stage zero table-read edges.
#[test]
fn test_standalone_select_without_function() {
    let source = "SELECT * FROM users;";

    let path = PathBuf::from("query.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    SqlGraphBuilder::new()
        .build_graph(&parsed, source.as_bytes(), &path, &mut graph)
        .expect("Graph building should succeed");

    let read_count = count_table_read_edges(&graph);
    assert_eq!(
        read_count, 0,
        "Standalone SELECT should not create edges without enclosing function"
    );
}
1418
/// Two `CREATE TABLE` statements are expected to stage exactly two export
/// edges, one per table.
#[test]
fn test_export_edges_for_table_definitions() {
    let source = r"
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL
);

CREATE TABLE orders (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id)
);
";

    let path = PathBuf::from("schema.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    let outcome =
        SqlGraphBuilder::new().build_graph(&parsed, source.as_bytes(), &path, &mut graph);
    outcome.expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert_eq!(
        export_count, 2,
        "Expected 2 Export edges (users and orders), got {export_count}"
    );
}
1448
/// One table, one view and one materialized view are expected to stage
/// exactly three export edges.
#[test]
fn test_export_edges_for_view_definitions() {
    let source = r"
CREATE TABLE users (id INT, created_at TIMESTAMP);

CREATE VIEW active_users AS
SELECT * FROM users WHERE created_at > NOW() - INTERVAL '30 days';

CREATE MATERIALIZED VIEW user_stats AS
SELECT COUNT(*) as total FROM users;
";

    let path = PathBuf::from("views.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    SqlGraphBuilder::new()
        .build_graph(&parsed, source.as_bytes(), &path, &mut graph)
        .expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert_eq!(
        export_count, 3,
        "Expected 3 Export edges (1 table + 2 views), got {export_count}"
    );
}
1477
/// Two functions and one trigger are expected to stage at least three
/// export edges. The bound is a minimum, not an exact count.
#[test]
fn test_export_edges_for_functions_and_triggers() {
    let source = r"
CREATE FUNCTION get_balance(account_id INT) RETURNS BIGINT AS $$
BEGIN
RETURN 42;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION update_balance() RETURNS TRIGGER AS $$
BEGIN
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER balance_updated
BEFORE INSERT ON accounts
FOR EACH ROW
EXECUTE FUNCTION update_balance();
";

    let path = PathBuf::from("banking.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    let outcome =
        SqlGraphBuilder::new().build_graph(&parsed, source.as_bytes(), &path, &mut graph);
    outcome.expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert!(
        export_count >= 3,
        "Expected at least 3 Export edges (2 functions + 1 trigger), got {export_count}"
    );
}
1517
/// A table and a function that are both declared under the `public.` schema
/// prefix are expected to stage exactly two export edges.
#[test]
fn test_export_edges_with_schema_qualified_names() {
    let source = r"
CREATE TABLE public.customers (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL
);

CREATE FUNCTION public.get_customer_name(cust_id INT) RETURNS TEXT AS $$
BEGIN
RETURN 'test';
END;
$$ LANGUAGE plpgsql;
";

    let path = PathBuf::from("public_schema.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    SqlGraphBuilder::new()
        .build_graph(&parsed, source.as_bytes(), &path, &mut graph)
        .expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert_eq!(
        export_count, 2,
        "Expected 2 Export edges (table + function), got {export_count}"
    );
}
1549
/// A script that mixes a table, a view and a function is expected to stage
/// exactly three export edges.
#[test]
fn test_mixed_database_objects_exports() {
    let source = r"
CREATE TABLE accounts (
id SERIAL PRIMARY KEY,
balance_cents BIGINT NOT NULL
);

CREATE VIEW positive_balances AS
SELECT * FROM accounts WHERE balance_cents > 0;

CREATE FUNCTION get_balance(account_id INT) RETURNS BIGINT AS $$
BEGIN
RETURN (SELECT balance_cents FROM accounts WHERE id = account_id);
END;
$$ LANGUAGE plpgsql;
";

    let path = PathBuf::from("mixed.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    let outcome =
        SqlGraphBuilder::new().build_graph(&parsed, source.as_bytes(), &path, &mut graph);
    outcome.expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert_eq!(
        export_count, 3,
        "Expected 3 Export edges (table + view + function), got {export_count}"
    );
}
1584
/// An empty SQL source must build successfully and stage zero export edges.
#[test]
fn test_no_exports_for_empty_file() {
    let source = "";

    let path = PathBuf::from("empty.sql");
    let mut graph = StagingGraph::new();
    let parsed = parse_sql(source);

    SqlGraphBuilder::new()
        .build_graph(&parsed, source.as_bytes(), &path, &mut graph)
        .expect("Graph building should succeed");

    let export_count = count_export_edges(&graph);
    assert_eq!(
        export_count, 0,
        "Expected 0 Export edges for empty file, got {export_count}"
    );
}
1603}