1use std::fmt;
6use std::ops::ControlFlow;
7
8use crate::error::Error;
9use crate::extractor::table_extractor::TableReference;
10use crate::{helper, TableExtractor};
11use sqlparser::ast::{MergeClause, Statement, Visit, Visitor};
12use sqlparser::dialect::Dialect;
13use sqlparser::parser::Parser;
14
15pub fn extract_crud_tables(
29 dialect: &dyn Dialect,
30 sql: &str,
31) -> Result<Vec<Result<CrudTables, Error>>, Error> {
32 CrudTableExtractor::extract(dialect, sql)
33}
34
35#[derive(Default, Debug, PartialEq)]
37pub struct CrudTables {
38 pub create_tables: Vec<TableReference>,
39 pub read_tables: Vec<TableReference>,
40 pub update_tables: Vec<TableReference>,
41 pub delete_tables: Vec<TableReference>,
42}
43
44impl fmt::Display for CrudTables {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 let create_tables = self.format_tables(&self.create_tables);
47 let read_tables = self.format_tables(&self.read_tables);
48 let update_tables = self.format_tables(&self.update_tables);
49 let delete_tables = self.format_tables(&self.delete_tables);
50 write!(
51 f,
52 "Create: [{}], Read: [{}], Update: [{}], Delete: [{}]",
53 create_tables, read_tables, update_tables, delete_tables
54 )
55 }
56}
57
58impl CrudTables {
59 fn format_tables(&self, tables: &[TableReference]) -> String {
60 tables
61 .iter()
62 .map(|t| t.to_string())
63 .collect::<Vec<String>>()
64 .join(", ")
65 }
66}
67
68#[derive(Default, Debug)]
70pub struct CrudTableExtractor {
71 create_tables: Vec<TableReference>,
72 read_tables: Vec<TableReference>,
73 update_tables: Vec<TableReference>,
74 delete_tables: Vec<TableReference>,
75 possibly_aliased_delete_tables: Vec<TableReference>,
76}
77
78impl Visitor for CrudTableExtractor {
79 type Break = Error;
80
81 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
82 match statement {
83 Statement::Insert { table_name, .. } => {
84 match TableReference::try_from(table_name) {
85 Ok(table) => self.create_tables.push(table),
86 Err(e) => return ControlFlow::Break(e),
87 }
88 self.read_tables = helper::calc_difference_of_tables(
89 self.read_tables.clone(),
90 self.create_tables.clone(),
91 );
92 }
93 Statement::Update { table, .. } => {
94 match TableExtractor::extract_from_table_node(table) {
95 Ok(tables) => tables
96 .0
97 .into_iter()
98 .for_each(|table| self.update_tables.push(table)),
99 Err(e) => return ControlFlow::Break(e),
100 }
101 self.read_tables = helper::calc_difference_of_tables(
102 self.read_tables.clone(),
103 self.update_tables.clone(),
104 );
105 }
106 Statement::Delete { tables, from, .. } => {
107 if !tables.is_empty() {
110 for table in tables {
111 match TableReference::try_from(table) {
112 Ok(table) => self.possibly_aliased_delete_tables.push(table),
113 Err(e) => return ControlFlow::Break(e),
114 }
115 }
116 } else {
117 for table_with_join in from {
118 match TableExtractor::extract_from_table_node(table_with_join) {
119 Ok(tables) => tables
120 .0
121 .into_iter()
122 .for_each(|table| self.possibly_aliased_delete_tables.push(table)),
123 Err(e) => return ControlFlow::Break(e),
124 }
125 }
126 }
127 self.delete_tables = helper::resolve_aliased_tables(
128 self.possibly_aliased_delete_tables.clone(),
129 self.read_tables.clone(),
130 );
131 self.read_tables = helper::calc_difference_of_tables(
132 self.read_tables.clone(),
133 self.delete_tables.clone(),
134 );
135 }
136 Statement::Merge { table, clauses, .. } => {
137 let target_table = match TableReference::try_from(table) {
138 Ok(table) => table,
139 Err(e) => return ControlFlow::Break(e),
140 };
141 let (mut inserted, mut updated, mut deleted) = (false, false, false);
142 clauses.iter().for_each(|clause| match clause {
143 MergeClause::MatchedUpdate { .. } => updated = true,
144 MergeClause::MatchedDelete { .. } => deleted = true,
145 MergeClause::NotMatched { .. } => inserted = true,
146 });
147 if inserted {
148 self.create_tables.push(target_table.clone());
149 }
150 if updated {
151 self.update_tables.push(target_table.clone());
152 }
153 if deleted {
154 self.delete_tables.push(target_table.clone());
155 }
156 self.read_tables =
157 helper::calc_difference_of_tables(self.read_tables.clone(), vec![target_table]);
158 }
159 _ => {}
160 }
161 ControlFlow::Continue(())
162 }
163}
164
165impl CrudTableExtractor {
166 pub fn extract(
168 dialect: &dyn Dialect,
169 sql: &str,
170 ) -> Result<Vec<Result<CrudTables, Error>>, Error> {
171 let statements = Parser::parse_sql(dialect, sql)?;
172 let results = statements
173 .iter()
174 .map(Self::extract_from_statement)
175 .collect::<Vec<Result<CrudTables, Error>>>();
176 Ok(results)
177 }
178
179 fn extract_from_statement(statement: &Statement) -> Result<CrudTables, Error> {
180 let mut visitor = CrudTableExtractor {
181 read_tables: TableExtractor::extract_from_statement(statement)?.0,
182 ..Default::default()
183 };
184 match statement.visit(&mut visitor) {
185 ControlFlow::Break(e) => Err(e),
186 ControlFlow::Continue(()) => Ok(CrudTables {
187 create_tables: visitor.create_tables,
188 read_tables: visitor.read_tables,
189 update_tables: visitor.update_tables,
190 delete_tables: visitor.delete_tables,
191 }),
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::all_dialects;
    use sqlparser::dialect::MySqlDialect;

    /// Runs the extraction of `sql` under every dialect in `dialects` and asserts
    /// the result equals `expected`, naming the dialect on failure.
    fn assert_crud_table_extraction(
        sql: &str,
        expected: Vec<Result<CrudTables, Error>>,
        dialects: Vec<Box<dyn Dialect>>,
    ) {
        for dialect in dialects {
            let result = CrudTableExtractor::extract(dialect.as_ref(), sql).unwrap();
            assert_eq!(result, expected, "Failed for dialect: {dialect:?}")
        }
    }

    /// Unqualified table reference without an alias.
    fn table(name: &str) -> TableReference {
        TableReference {
            catalog: None,
            schema: None,
            name: name.into(),
            alias: None,
        }
    }

    /// Unqualified table reference with an alias.
    fn aliased(name: &str, alias: &str) -> TableReference {
        TableReference {
            alias: Some(alias.into()),
            ..table(name)
        }
    }

    /// Fully qualified `catalog.schema.name` reference with an optional alias.
    fn qualified(catalog: &str, schema: &str, name: &str, alias: Option<&str>) -> TableReference {
        TableReference {
            catalog: Some(catalog.into()),
            schema: Some(schema.into()),
            name: name.into(),
            alias: alias.map(Into::into),
        }
    }

    #[test]
    fn test_single_statement() {
        let sql = "SELECT a FROM t1";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1")],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT a FROM t1; SELECT b FROM t2";
        let expected = vec![
            Ok(CrudTables {
                read_tables: vec![table("t1")],
                ..Default::default()
            }),
            Ok(CrudTables {
                read_tables: vec![table("t2")],
                ..Default::default()
            }),
        ];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_alias() {
        let sql = "SELECT a FROM t1 AS t1_alias";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![aliased("t1", "t1_alias")],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier() {
        let sql = "SELECT a FROM catalog.schema.table";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![qualified("catalog", "schema", "table", None)],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier_and_alias() {
        let sql = "SELECT a FROM catalog.schema.table AS table_alias";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![qualified("catalog", "schema", "table", Some("table_alias"))],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_error_with_too_many_identifiers() {
        let sql = "INSERT INTO catalog.schema.table.extra (a) VALUES (1)";
        let expected = vec![Err(Error::AnalysisError(
            "Too many identifiers provided".to_string(),
        ))];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    mod delete_statement {
        use super::*;

        #[test]
        fn test_delete_statement() {
            let sql = "DELETE FROM t1";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![table("t1")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_table_identifier() {
            let sql = "DELETE FROM catalog.schema.t1";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![qualified("catalog", "schema", "t1", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_alias() {
            let sql = "DELETE FROM t1 AS t1_alias";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![aliased("t1", "t1_alias")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax() {
            let sql = "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t1"), table("t2"), table("t3")],
                delete_tables: vec![table("t1"), table("t2")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_alias() {
            let sql =
                "DELETE t1_alias, t2_alias FROM t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![
                    aliased("t1", "t1_alias"),
                    aliased("t2", "t2_alias"),
                    table("t3"),
                ],
                delete_tables: vec![aliased("t1", "t1_alias"), aliased("t2", "t2_alias")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_using() {
            let sql = "DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t1"), table("t2"), table("t3")],
                delete_tables: vec![table("t1"), table("t2")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_using_with_alias() {
            let sql = "DELETE FROM t1_alias, t2_alias USING t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![
                    aliased("t1", "t1_alias"),
                    aliased("t2", "t2_alias"),
                    table("t3"),
                ],
                delete_tables: vec![aliased("t1", "t1_alias"), aliased("t2", "t2_alias")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    mod insert_statement {
        use super::*;

        #[test]
        fn test_insert_statement() {
            let sql = "INSERT INTO t1 (a) VALUES (1)";
            let expected = vec![Ok(CrudTables {
                create_tables: vec![table("t1")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_insert_select_statement() {
            let sql = "INSERT INTO t1 (a) SELECT a FROM t2 AS t2_alias INNER JOIN t3 USING (id)";
            let expected = vec![Ok(CrudTables {
                create_tables: vec![table("t1")],
                read_tables: vec![aliased("t2", "t2_alias"), table("t3")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    mod update_statement {
        use super::*;

        #[test]
        fn test_update_statement() {
            let sql = "UPDATE t1 SET a=1";
            let expected = vec![Ok(CrudTables {
                update_tables: vec![table("t1")],
                ..Default::default()
            })];
            assert_crud_table_extraction(
                sql,
                expected,
                vec![Box::new(MySqlDialect {}) as Box<dyn Dialect>],
            );
        }

        #[test]
        fn test_update_statement_with_alias() {
            let sql = "UPDATE t1 AS t1_alias INNER JOIN t2 ON t1_alias.a = t2.a SET t1_alias.b = t2.b WHERE t2.c = (SELECT c FROM t3)";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t3")],
                update_tables: vec![aliased("t1", "t1_alias"), table("t2")],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    #[test]
    fn test_merge_statement() {
        let sql = "MERGE INTO t1 AS t1_alias USING t2 AS t2_alias ON t1_alias.a = t2_alias.a \
                   WHEN MATCHED AND t2_alias.b = 1 THEN DELETE \
                   WHEN MATCHED AND t2_alias.b = 2 THEN UPDATE SET t1_alias.b = t2_alias.b \
                   WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2_alias.a, t2_alias.b)";
        let expected = vec![Ok(CrudTables {
            create_tables: vec![aliased("t1", "t1_alias")],
            read_tables: vec![aliased("t2", "t2_alias")],
            update_tables: vec![aliased("t1", "t1_alias")],
            delete_tables: vec![aliased("t1", "t1_alias")],
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_create_table_statement() {
        let sql = "CREATE TABLE t1 (a INT)";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1")],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_alters_table_statement() {
        let sql = "ALTER TABLE t1 ADD COLUMN a INT";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1")],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }
}