1use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::ffi::OsStr;
14use std::path::{Path, PathBuf};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ImpactedFile {
22 pub path: String,
24 pub tables_referenced: Vec<String>,
26 pub columns_referenced: Vec<String>,
29 pub hits: Vec<QueryHit>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct QueryHit {
35 pub line: usize,
36 pub snippet: String,
37 pub match_type: MatchType,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub enum MatchType {
42 SqlLiteral,
44 OrmReference,
46 FieldReference,
48}
49
50#[derive(Debug, Clone, Default, Serialize, Deserialize)]
51pub struct ImpactReport {
52 pub files_scanned: usize,
54 pub impacted_files: Vec<ImpactedFile>,
56 pub table_file_map: HashMap<String, Vec<String>>,
58 pub column_file_map: HashMap<String, Vec<String>>,
60}
61
/// File extensions (without the dot, compared case-sensitively) that the
/// directory walk treats as scannable source files.
const SOURCE_EXTENSIONS: &[&str] = &[
    "rs",
    "go",
    "py",
    "js",
    "ts",
    "jsx",
    "tsx",
    "rb",
    "java",
    "cs",
    "php",
    "sql",
    "graphql",
];
70
/// Scans source trees for textual mentions of database tables and columns.
pub struct ImpactScanner {
    /// Lowercased table identifiers to look for.
    tables: Vec<String>,
    /// Lowercased column identifiers to look for.
    columns: Vec<String>,
}
77
78impl ImpactScanner {
79 pub fn new(tables: Vec<String>, columns: Vec<String>) -> Self {
84 Self::new_with_options(tables, columns, true)
85 }
86
87 pub fn new_scan_short(tables: Vec<String>, columns: Vec<String>) -> Self {
89 Self::new_with_options(tables, columns, false)
90 }
91
92 fn new_with_options(tables: Vec<String>, columns: Vec<String>, skip_short: bool) -> Self {
93 let filter = |idents: Vec<String>| -> Vec<String> {
94 idents
95 .into_iter()
96 .filter(|s| !skip_short || s.chars().count() >= 4)
97 .map(|s| s.to_lowercase())
98 .collect()
99 };
100 Self {
101 tables: filter(tables),
102 columns: filter(columns),
103 }
104 }
105
106 pub fn scan(&self, root_dir: &Path) -> ImpactReport {
109 let paths = collect_source_files(root_dir);
111 let total = paths.len();
112
113 let impacted_files: Vec<ImpactedFile> = paths
114 .par_iter()
115 .filter_map(|path| self.scan_file(path))
116 .collect();
117
118 let mut table_file_map: HashMap<String, Vec<String>> = HashMap::new();
120 let mut column_file_map: HashMap<String, Vec<String>> = HashMap::new();
121 for f in &impacted_files {
122 for t in &f.tables_referenced {
123 table_file_map
124 .entry(t.clone())
125 .or_default()
126 .push(f.path.clone());
127 }
128 for c in &f.columns_referenced {
129 column_file_map
130 .entry(c.clone())
131 .or_default()
132 .push(f.path.clone());
133 }
134 }
135
136 ImpactReport {
137 files_scanned: total,
138 impacted_files,
139 table_file_map,
140 column_file_map,
141 }
142 }
143
144 fn scan_file(&self, path: &Path) -> Option<ImpactedFile> {
147 let content = std::fs::read_to_string(path).ok()?;
148 let content_lower = content.to_lowercase();
149
150 let mut tables_found: Vec<String> = Vec::new();
151 let mut columns_found: Vec<String> = Vec::new();
152 let mut hits: Vec<QueryHit> = Vec::new();
153
154 for (line_idx, line) in content.lines().enumerate() {
155 let line_lower = line.to_lowercase();
156
157 for table in &self.tables {
158 if line_lower.contains(table.as_str()) {
159 if !tables_found.contains(table) {
160 tables_found.push(table.clone());
161 }
162 let match_type = classify_match(&line_lower, table);
163 hits.push(QueryHit {
164 line: line_idx + 1,
165 snippet: line.trim().chars().take(200).collect(),
166 match_type,
167 });
168 }
169 }
170
171 for col in &self.columns {
172 if line_lower.contains(col.as_str())
173 && !content_lower.contains(&format!("-- {}", col))
174 {
175 if !columns_found.contains(col) {
176 columns_found.push(col.clone());
177 }
178 if !hits.iter().any(|h| h.line == line_idx + 1) {
180 let match_type = classify_match(&line_lower, col);
181 hits.push(QueryHit {
182 line: line_idx + 1,
183 snippet: line.trim().chars().take(200).collect(),
184 match_type,
185 });
186 }
187 }
188 }
189 }
190
191 if tables_found.is_empty() && columns_found.is_empty() {
192 return None;
193 }
194
195 let rel_path = path.to_string_lossy().to_string();
196
197 Some(ImpactedFile {
198 path: rel_path,
199 tables_referenced: tables_found,
200 columns_referenced: columns_found,
201 hits,
202 })
203 }
204}
205
206fn classify_match(line: &str, token: &str) -> MatchType {
209 let orm_patterns = [
211 "select(",
212 "where(",
213 "findone",
214 "findall",
215 "findmany",
216 "create(",
217 "update(",
218 "delete(",
219 "include:",
220 "prisma.",
221 "model.",
222 ".query(",
223 "execute(",
224 "from(",
225 "join(",
226 "diesel::",
227 "querybuilder",
228 "activerecord",
229 "sqlalchemy",
230 ];
231
232 let field_patterns = ["include:", "select:", "fields:", "columns:", "attributes:"];
233
234 if field_patterns.iter().any(|p| line.contains(p)) {
235 return MatchType::FieldReference;
236 }
237
238 if orm_patterns.iter().any(|p| line.contains(p)) {
239 return MatchType::OrmReference;
240 }
241
242 let sql_keywords = ["from ", "join ", "into ", "update ", "\"", "'", "`"];
244 if sql_keywords.iter().any(|k| {
245 if let Some(pos) = line.find(k) {
246 line[pos..].contains(token)
247 } else {
248 false
249 }
250 }) {
251 return MatchType::SqlLiteral;
252 }
253
254 MatchType::OrmReference
255}
256
257fn collect_source_files(root: &Path) -> Vec<PathBuf> {
260 let mut files = Vec::new();
261 collect_recursive(root, &mut files);
262 files
263}
264
265fn collect_recursive(dir: &Path, out: &mut Vec<PathBuf>) {
266 let Ok(entries) = std::fs::read_dir(dir) else {
267 return;
268 };
269
270 for entry in entries.flatten() {
271 let path = entry.path();
272
273 let name = path.file_name().and_then(OsStr::to_str).unwrap_or("");
275 if name.starts_with('.')
276 || matches!(
277 name,
278 "node_modules" | "target" | "dist" | "build" | "vendor" | "__pycache__" | ".git"
279 )
280 {
281 continue;
282 }
283
284 if path.is_dir() {
285 collect_recursive(&path, out);
286 } else if let Some(ext) = path.extension().and_then(OsStr::to_str) {
287 if SOURCE_EXTENSIONS.contains(&ext) {
288 out.push(path);
289 }
290 }
291 }
292}
293
294#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExtractedSql {
301 pub source_file: String,
303 pub line: usize,
305 pub column: Option<usize>,
307 pub sql: String,
309 pub context: SqlContext,
311 pub confidence: f32,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
317pub enum SqlContext {
318 RawSql,
320 PrismaRaw,
322 TypeOrm,
324 Sequelize,
326 SqlAlchemy,
328 Gorm,
330 Diesel,
332 EntityFramework,
334 Eloquent,
336 ActiveRecord,
338 Unknown,
340}
341
342impl std::fmt::Display for SqlContext {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 match self {
345 SqlContext::RawSql => write!(f, "Raw SQL"),
346 SqlContext::PrismaRaw => write!(f, "Prisma"),
347 SqlContext::TypeOrm => write!(f, "TypeORM"),
348 SqlContext::Sequelize => write!(f, "Sequelize"),
349 SqlContext::SqlAlchemy => write!(f, "SQLAlchemy"),
350 SqlContext::Gorm => write!(f, "GORM"),
351 SqlContext::Diesel => write!(f, "Diesel"),
352 SqlContext::EntityFramework => write!(f, "Entity Framework"),
353 SqlContext::Eloquent => write!(f, "Eloquent"),
354 SqlContext::ActiveRecord => write!(f, "ActiveRecord"),
355 SqlContext::Unknown => write!(f, "Unknown"),
356 }
357 }
358}
359
360#[derive(Debug, Clone, Default, Serialize, Deserialize)]
362pub struct SqlExtractionReport {
363 pub files_scanned: usize,
365 pub extracted: Vec<ExtractedSql>,
367 pub dangerous: Vec<ExtractedSql>,
369 pub by_context: HashMap<String, usize>,
371}
372
373pub struct SqlExtractor {
375 patterns: Vec<SqlExtractionPattern>,
377}
378
379struct SqlExtractionPattern {
380 regex: regex::Regex,
381 context: SqlContext,
382 extensions: Vec<&'static str>,
384 capture_group: usize,
386 confidence: f32,
388}
389
390impl Default for SqlExtractor {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396impl SqlExtractor {
397 pub fn new() -> Self {
399 let patterns = Self::build_patterns();
400 Self { patterns }
401 }
402
403 fn build_patterns() -> Vec<SqlExtractionPattern> {
404 let mut patterns = Vec::new();
405
406 let mut add =
408 |pattern: &str, ctx: SqlContext, exts: &[&'static str], group: usize, conf: f32| {
409 if let Ok(re) = regex::Regex::new(pattern) {
410 patterns.push(SqlExtractionPattern {
411 regex: re,
412 context: ctx,
413 extensions: exts.to_vec(),
414 capture_group: group,
415 confidence: conf,
416 });
417 }
418 };
419
420 add(
422 r#"\$queryRaw\s*`([^`]+)`"#,
423 SqlContext::PrismaRaw,
424 &["ts", "js", "tsx", "jsx"],
425 1,
426 0.95,
427 );
428 add(
429 r#"\$executeRaw\s*`([^`]+)`"#,
430 SqlContext::PrismaRaw,
431 &["ts", "js", "tsx", "jsx"],
432 1,
433 0.95,
434 );
435 add(
436 r#"Prisma\.sql\s*`([^`]+)`"#,
437 SqlContext::PrismaRaw,
438 &["ts", "js", "tsx", "jsx"],
439 1,
440 0.9,
441 );
442
443 add(
445 r#"\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
446 SqlContext::TypeOrm,
447 &["ts", "js", "tsx", "jsx"],
448 1,
449 0.85,
450 );
451 add(
452 r#"createQueryBuilder\s*\(\s*["']([^"']+)["']"#,
453 SqlContext::TypeOrm,
454 &["ts", "js", "tsx", "jsx"],
455 1,
456 0.8,
457 );
458 add(
459 r#"\.createQueryRunner\(\)\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
460 SqlContext::TypeOrm,
461 &["ts", "js"],
462 1,
463 0.9,
464 );
465
466 add(
468 r#"sequelize\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
469 SqlContext::Sequelize,
470 &["ts", "js", "tsx", "jsx"],
471 1,
472 0.9,
473 );
474 add(
475 r#"QueryTypes\.\w+.*["'`]([^"'`]+)["'`]"#,
476 SqlContext::Sequelize,
477 &["ts", "js"],
478 1,
479 0.85,
480 );
481
482 add(
484 r#"text\s*\(\s*["']([^"']+)["']"#,
485 SqlContext::SqlAlchemy,
486 &["py"],
487 1,
488 0.9,
489 );
490 add(
491 r#"execute\s*\(\s*["']([^"']+)["']"#,
492 SqlContext::SqlAlchemy,
493 &["py"],
494 1,
495 0.85,
496 );
497 add(
498 r#"session\.execute\s*\(\s*["']([^"']+)["']"#,
499 SqlContext::SqlAlchemy,
500 &["py"],
501 1,
502 0.9,
503 );
504 add(
505 r#"connection\.execute\s*\(\s*["']([^"']+)["']"#,
506 SqlContext::SqlAlchemy,
507 &["py"],
508 1,
509 0.9,
510 );
511
512 add(
514 r#"cursor\.execute\s*\(\s*["']([^"']+)["']"#,
515 SqlContext::SqlAlchemy,
516 &["py"],
517 1,
518 0.9,
519 );
520 add(
521 r#"\.raw\s*\(\s*["']([^"']+)["']"#,
522 SqlContext::SqlAlchemy,
523 &["py"],
524 1,
525 0.85,
526 );
527
528 add(
530 r#"\.Raw\s*\(\s*["'`]([^"'`]+)["'`]"#,
531 SqlContext::Gorm,
532 &["go"],
533 1,
534 0.9,
535 );
536 add(
537 r#"\.Exec\s*\(\s*["'`]([^"'`]+)["'`]"#,
538 SqlContext::Gorm,
539 &["go"],
540 1,
541 0.9,
542 );
543 add(
544 r#"db\.Query\s*\(\s*["'`]([^"'`]+)["'`]"#,
545 SqlContext::Gorm,
546 &["go"],
547 1,
548 0.85,
549 );
550
551 add(
553 r#"sql_query\s*\(\s*["']([^"']+)["']"#,
554 SqlContext::Diesel,
555 &["rs"],
556 1,
557 0.9,
558 );
559 add(
560 r#"diesel::sql_query\s*\(\s*["']([^"']+)["']"#,
561 SqlContext::Diesel,
562 &["rs"],
563 1,
564 0.95,
565 );
566
567 add(
569 r#"\.FromSqlRaw\s*\(\s*["']([^"']+)["']"#,
570 SqlContext::EntityFramework,
571 &["cs"],
572 1,
573 0.9,
574 );
575 add(
576 r#"\.ExecuteSqlRaw\s*\(\s*["']([^"']+)["']"#,
577 SqlContext::EntityFramework,
578 &["cs"],
579 1,
580 0.9,
581 );
582 add(
583 r#"SqlQuery<[^>]+>\s*\(\s*["']([^"']+)["']"#,
584 SqlContext::EntityFramework,
585 &["cs"],
586 1,
587 0.85,
588 );
589
590 add(
592 r#"DB::raw\s*\(\s*["']([^"']+)["']"#,
593 SqlContext::Eloquent,
594 &["php"],
595 1,
596 0.9,
597 );
598 add(
599 r#"DB::statement\s*\(\s*["']([^"']+)["']"#,
600 SqlContext::Eloquent,
601 &["php"],
602 1,
603 0.9,
604 );
605 add(
606 r#"DB::select\s*\(\s*["']([^"']+)["']"#,
607 SqlContext::Eloquent,
608 &["php"],
609 1,
610 0.85,
611 );
612 add(
613 r#"DB::insert\s*\(\s*["']([^"']+)["']"#,
614 SqlContext::Eloquent,
615 &["php"],
616 1,
617 0.85,
618 );
619 add(
620 r#"DB::update\s*\(\s*["']([^"']+)["']"#,
621 SqlContext::Eloquent,
622 &["php"],
623 1,
624 0.85,
625 );
626 add(
627 r#"DB::delete\s*\(\s*["']([^"']+)["']"#,
628 SqlContext::Eloquent,
629 &["php"],
630 1,
631 0.85,
632 );
633
634 add(
636 r#"execute\s*\(\s*["']([^"']+)["']"#,
637 SqlContext::ActiveRecord,
638 &["rb"],
639 1,
640 0.85,
641 );
642 add(
643 r#"exec_query\s*\(\s*["']([^"']+)["']"#,
644 SqlContext::ActiveRecord,
645 &["rb"],
646 1,
647 0.9,
648 );
649 add(
650 r#"connection\.execute\s*\(\s*["']([^"']+)["']"#,
651 SqlContext::ActiveRecord,
652 &["rb"],
653 1,
654 0.9,
655 );
656 add(
657 r#"find_by_sql\s*\(\s*["']([^"']+)["']"#,
658 SqlContext::ActiveRecord,
659 &["rb"],
660 1,
661 0.9,
662 );
663
664 add(
667 r#"["'`]((?:SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP|TRUNCATE)\s+[^"'`]{10,})["'`]"#,
668 SqlContext::RawSql,
669 &["*"],
670 1,
671 0.7,
672 );
673
674 patterns
675 }
676
677 pub fn extract_from_file(&self, path: &Path) -> Vec<ExtractedSql> {
679 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
680
681 let content = match std::fs::read_to_string(path) {
682 Ok(c) => c,
683 Err(_) => return vec![],
684 };
685
686 let path_str = path.to_string_lossy().to_string();
687 let mut results = Vec::new();
688
689 for (line_idx, line) in content.lines().enumerate() {
690 for pattern in &self.patterns {
691 if !pattern.extensions.contains(&"*") && !pattern.extensions.contains(&ext) {
693 continue;
694 }
695
696 for cap in pattern.regex.captures_iter(line) {
697 let sql = if pattern.capture_group > 0 {
698 cap.get(pattern.capture_group)
699 .map(|m| m.as_str())
700 .unwrap_or("")
701 } else {
702 cap.get(0).map(|m| m.as_str()).unwrap_or("")
703 };
704
705 let sql = sql.trim().to_string();
706
707 if sql.len() < 5 {
709 continue;
710 }
711
712 if !Self::looks_like_sql(&sql) {
714 continue;
715 }
716
717 results.push(ExtractedSql {
718 source_file: path_str.clone(),
719 line: line_idx + 1,
720 column: cap.get(1).map(|m| m.start()),
721 sql,
722 context: pattern.context.clone(),
723 confidence: pattern.confidence,
724 });
725 }
726 }
727 }
728
729 results
730 }
731
732 fn looks_like_sql(s: &str) -> bool {
734 let upper = s.to_uppercase();
735 let sql_keywords = [
736 "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "TRUNCATE", "FROM",
737 "WHERE", "JOIN", "TABLE", "INDEX", "COLUMN",
738 ];
739 sql_keywords.iter().any(|kw| upper.contains(kw))
740 }
741
742 fn is_dangerous_sql(sql: &str) -> bool {
744 let upper = sql.to_uppercase();
745 upper.contains("DROP ")
746 || upper.contains("TRUNCATE ")
747 || upper.contains("DELETE ")
748 || upper.contains("ALTER ")
749 || upper.contains("CREATE INDEX")
750 }
751
752 pub fn scan_directory(&self, root: &Path) -> SqlExtractionReport {
754 let files = collect_source_files(root);
755 let total = files.len();
756
757 let extracted: Vec<ExtractedSql> = files
758 .par_iter()
759 .flat_map(|path| self.extract_from_file(path))
760 .collect();
761
762 let dangerous: Vec<ExtractedSql> = extracted
764 .iter()
765 .filter(|e| Self::is_dangerous_sql(&e.sql))
766 .cloned()
767 .collect();
768
769 let mut by_context: HashMap<String, usize> = HashMap::new();
771 for e in &extracted {
772 *by_context.entry(e.context.to_string()).or_insert(0) += 1;
773 }
774
775 SqlExtractionReport {
776 files_scanned: total,
777 extracted,
778 dangerous,
779 by_context,
780 }
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    /// A source file inside a private temporary directory.
    ///
    /// The directory name includes the process id so concurrent test
    /// processes do not collide, and the directory is removed on drop, which
    /// also covers the case where an assertion fails.
    struct TempSource {
        dir: std::path::PathBuf,
        file: std::path::PathBuf,
    }

    impl TempSource {
        fn new(label: &str, file_name: &str, code: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("schema-risk-test-{label}-{}", std::process::id()));
            std::fs::create_dir_all(&dir).expect("create temp dir");
            let file = dir.join(file_name);
            std::fs::write(&file, code).expect("write temp source file");
            Self { dir, file }
        }
    }

    impl Drop for TempSource {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn test_sql_extractor_prisma() {
        let extractor = SqlExtractor::new();
        let code =
            r#"const result = await prisma.$queryRaw`SELECT * FROM users WHERE id = ${id}`;"#;
        let source = TempSource::new("prisma", "test.ts", code);

        let results = extractor.extract_from_file(&source.file);
        assert!(!results.is_empty());
        assert_eq!(results[0].context, SqlContext::PrismaRaw);
        assert!(results[0].sql.contains("SELECT"));
    }

    #[test]
    fn test_sql_extractor_raw_sql() {
        let extractor = SqlExtractor::new();
        let code = r#"const query = "SELECT * FROM users WHERE active = true";"#;
        let source = TempSource::new("raw", "test.js", code);

        let results = extractor.extract_from_file(&source.file);
        assert!(!results.is_empty());
        assert!(results[0].sql.contains("SELECT"));
    }

    #[test]
    fn test_dangerous_sql_detection() {
        assert!(SqlExtractor::is_dangerous_sql("DROP TABLE users"));
        assert!(SqlExtractor::is_dangerous_sql(
            "DELETE FROM users WHERE id = 1"
        ));
        assert!(SqlExtractor::is_dangerous_sql("TRUNCATE TABLE sessions"));
        assert!(SqlExtractor::is_dangerous_sql(
            "ALTER TABLE users ADD COLUMN age INT"
        ));
        assert!(!SqlExtractor::is_dangerous_sql("SELECT * FROM users"));
        assert!(!SqlExtractor::is_dangerous_sql(
            "INSERT INTO users (name) VALUES ('test')"
        ));
    }

    #[test]
    fn test_looks_like_sql() {
        assert!(SqlExtractor::looks_like_sql("SELECT * FROM users"));
        assert!(SqlExtractor::looks_like_sql(
            "INSERT INTO users (name) VALUES ('test')"
        ));
        assert!(SqlExtractor::looks_like_sql("DROP TABLE users"));
        assert!(!SqlExtractor::looks_like_sql("Hello world"));
        assert!(!SqlExtractor::looks_like_sql("const x = 5"));
    }
}