sql_insight/extractor/
crud_table_extractor.rs

//! An extractor that extracts CRUD tables from SQL queries.
//!
//! See [`extract_crud_tables`](crate::extract_crud_tables()) as the entry point for extracting CRUD tables from SQL.
4
5use std::fmt;
6use std::ops::ControlFlow;
7
8use crate::error::Error;
9use crate::extractor::table_extractor::TableReference;
10use crate::{helper, TableExtractor};
11use sqlparser::ast::{MergeClause, Statement, Visit, Visitor};
12use sqlparser::dialect::Dialect;
13use sqlparser::parser::Parser;
14
/// Convenience function to extract CRUD tables from SQL.
///
/// This is a thin wrapper that forwards both arguments to
/// [`CrudTableExtractor::extract`] and returns its result unchanged.
/// The outer `Result` is the result of that call as a whole; the `Vec` inside
/// it carries one `Result<CrudTables, Error>` per parsed statement.
///
/// ## Example
///
/// ```rust
/// use sql_insight::sqlparser::dialect::GenericDialect;
///
/// let dialect = GenericDialect {};
/// let sql = "INSERT INTO t1 (a) SELECT a FROM t2";
/// let result = sql_insight::extract_crud_tables(&dialect, sql).unwrap();
/// println!("{:#?}", result);
/// assert_eq!(result[0].as_ref().unwrap().to_string(), "Create: [t1], Read: [t2], Update: [], Delete: []");
/// ```
pub fn extract_crud_tables(
    dialect: &dyn Dialect,
    sql: &str,
) -> Result<Vec<Result<CrudTables, Error>>, Error> {
    CrudTableExtractor::extract(dialect, sql)
}
34
35/// [`CrudTables`] represents the tables involved in CRUD operations.
36#[derive(Default, Debug, PartialEq)]
37pub struct CrudTables {
38    pub create_tables: Vec<TableReference>,
39    pub read_tables: Vec<TableReference>,
40    pub update_tables: Vec<TableReference>,
41    pub delete_tables: Vec<TableReference>,
42}
43
44impl fmt::Display for CrudTables {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        let create_tables = self.format_tables(&self.create_tables);
47        let read_tables = self.format_tables(&self.read_tables);
48        let update_tables = self.format_tables(&self.update_tables);
49        let delete_tables = self.format_tables(&self.delete_tables);
50        write!(
51            f,
52            "Create: [{}], Read: [{}], Update: [{}], Delete: [{}]",
53            create_tables, read_tables, update_tables, delete_tables
54        )
55    }
56}
57
58impl CrudTables {
59    fn format_tables(&self, tables: &[TableReference]) -> String {
60        tables
61            .iter()
62            .map(|t| t.to_string())
63            .collect::<Vec<String>>()
64            .join(", ")
65    }
66}
67
/// A visitor to extract CRUD tables from SQL.
///
/// The fields accumulate the tables found while a statement is visited,
/// grouped by the kind of operation performed on them. A fresh value is built
/// for every statement by `extract_from_statement`, which seeds `read_tables`
/// from `TableExtractor` before the visit starts.
#[derive(Default, Debug)]
pub struct CrudTableExtractor {
    /// Targets of `INSERT`, plus the `MERGE` target when a `NotMatched` clause is present.
    create_tables: Vec<TableReference>,
    /// Tables still considered reads; write targets are subtracted from it via
    /// `helper::calc_difference_of_tables` as they are discovered.
    read_tables: Vec<TableReference>,
    /// Targets of `UPDATE`, plus the `MERGE` target when a `MatchedUpdate` clause is present.
    update_tables: Vec<TableReference>,
    /// Targets of `DELETE` after alias resolution, plus the `MERGE` target when a
    /// `MatchedDelete` clause is present.
    delete_tables: Vec<TableReference>,
    /// `DELETE` targets exactly as written in the statement, before
    /// `helper::resolve_aliased_tables` maps them against `read_tables`.
    possibly_aliased_delete_tables: Vec<TableReference>,
}
77
78impl Visitor for CrudTableExtractor {
79    type Break = Error;
80
81    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
82        match statement {
83            Statement::Insert { table_name, .. } => {
84                match TableReference::try_from(table_name) {
85                    Ok(table) => self.create_tables.push(table),
86                    Err(e) => return ControlFlow::Break(e),
87                }
88                self.read_tables = helper::calc_difference_of_tables(
89                    self.read_tables.clone(),
90                    self.create_tables.clone(),
91                );
92            }
93            Statement::Update { table, .. } => {
94                match TableExtractor::extract_from_table_node(table) {
95                    Ok(tables) => tables
96                        .0
97                        .into_iter()
98                        .for_each(|table| self.update_tables.push(table)),
99                    Err(e) => return ControlFlow::Break(e),
100                }
101                self.read_tables = helper::calc_difference_of_tables(
102                    self.read_tables.clone(),
103                    self.update_tables.clone(),
104                );
105            }
106            Statement::Delete { tables, from, .. } => {
107                // When tables are present, deletion sqls are these tables,
108                // and from clause is used as a data source.
109                if !tables.is_empty() {
110                    for table in tables {
111                        match TableReference::try_from(table) {
112                            Ok(table) => self.possibly_aliased_delete_tables.push(table),
113                            Err(e) => return ControlFlow::Break(e),
114                        }
115                    }
116                } else {
117                    for table_with_join in from {
118                        match TableExtractor::extract_from_table_node(table_with_join) {
119                            Ok(tables) => tables
120                                .0
121                                .into_iter()
122                                .for_each(|table| self.possibly_aliased_delete_tables.push(table)),
123                            Err(e) => return ControlFlow::Break(e),
124                        }
125                    }
126                }
127                self.delete_tables = helper::resolve_aliased_tables(
128                    self.possibly_aliased_delete_tables.clone(),
129                    self.read_tables.clone(),
130                );
131                self.read_tables = helper::calc_difference_of_tables(
132                    self.read_tables.clone(),
133                    self.delete_tables.clone(),
134                );
135            }
136            Statement::Merge { table, clauses, .. } => {
137                let target_table = match TableReference::try_from(table) {
138                    Ok(table) => table,
139                    Err(e) => return ControlFlow::Break(e),
140                };
141                let (mut inserted, mut updated, mut deleted) = (false, false, false);
142                clauses.iter().for_each(|clause| match clause {
143                    MergeClause::MatchedUpdate { .. } => updated = true,
144                    MergeClause::MatchedDelete { .. } => deleted = true,
145                    MergeClause::NotMatched { .. } => inserted = true,
146                });
147                if inserted {
148                    self.create_tables.push(target_table.clone());
149                }
150                if updated {
151                    self.update_tables.push(target_table.clone());
152                }
153                if deleted {
154                    self.delete_tables.push(target_table.clone());
155                }
156                self.read_tables =
157                    helper::calc_difference_of_tables(self.read_tables.clone(), vec![target_table]);
158            }
159            _ => {}
160        }
161        ControlFlow::Continue(())
162    }
163}
164
165impl CrudTableExtractor {
166    /// Extract CRUD tables from SQL.
167    pub fn extract(
168        dialect: &dyn Dialect,
169        sql: &str,
170    ) -> Result<Vec<Result<CrudTables, Error>>, Error> {
171        let statements = Parser::parse_sql(dialect, sql)?;
172        let results = statements
173            .iter()
174            .map(Self::extract_from_statement)
175            .collect::<Vec<Result<CrudTables, Error>>>();
176        Ok(results)
177    }
178
179    fn extract_from_statement(statement: &Statement) -> Result<CrudTables, Error> {
180        let mut visitor = CrudTableExtractor {
181            read_tables: TableExtractor::extract_from_statement(statement)?.0,
182            ..Default::default()
183        };
184        match statement.visit(&mut visitor) {
185            ControlFlow::Break(e) => Err(e),
186            ControlFlow::Continue(()) => Ok(CrudTables {
187                create_tables: visitor.create_tables,
188                read_tables: visitor.read_tables,
189                update_tables: visitor.update_tables,
190                delete_tables: visitor.delete_tables,
191            }),
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::all_dialects;
    use sqlparser::dialect::MySqlDialect;

    /// Builds the expected reference for a table written without catalog or schema.
    fn table(name: &str, alias: Option<&str>) -> TableReference {
        TableReference {
            catalog: None,
            schema: None,
            name: name.into(),
            alias: alias.map(Into::into),
        }
    }

    /// Builds the expected reference for a table written as `catalog.schema.name`.
    fn qualified_table(
        catalog: &str,
        schema: &str,
        name: &str,
        alias: Option<&str>,
    ) -> TableReference {
        TableReference {
            catalog: Some(catalog.into()),
            schema: Some(schema.into()),
            name: name.into(),
            alias: alias.map(Into::into),
        }
    }

    /// Runs the extractor on `sql` under every dialect in `dialects` and
    /// checks that each run yields exactly `expected`.
    fn assert_crud_table_extraction(
        sql: &str,
        expected: Vec<Result<CrudTables, Error>>,
        dialects: Vec<Box<dyn Dialect>>,
    ) {
        for dialect in dialects {
            let result = CrudTableExtractor::extract(dialect.as_ref(), sql).unwrap();
            assert_eq!(result, expected, "Failed for dialect: {dialect:?}");
        }
    }

    #[test]
    fn test_single_statement() {
        let sql = "SELECT a FROM t1";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1", None)],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT a FROM t1; SELECT b FROM t2";
        let expected = vec![
            Ok(CrudTables {
                read_tables: vec![table("t1", None)],
                ..Default::default()
            }),
            Ok(CrudTables {
                read_tables: vec![table("t2", None)],
                ..Default::default()
            }),
        ];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_alias() {
        let sql = "SELECT a FROM t1 AS t1_alias";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1", Some("t1_alias"))],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier() {
        let sql = "SELECT a FROM catalog.schema.table";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![qualified_table("catalog", "schema", "table", None)],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier_and_alias() {
        let sql = "SELECT a FROM catalog.schema.table AS table_alias";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![qualified_table(
                "catalog",
                "schema",
                "table",
                Some("table_alias"),
            )],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_error_with_too_many_identifiers() {
        let sql = "INSERT INTO catalog.schema.table.extra (a) VALUES (1)";
        let expected = vec![Err(Error::AnalysisError(
            "Too many identifiers provided".to_string(),
        ))];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    mod delete_statement {
        use super::*;

        #[test]
        fn test_delete_statement() {
            let sql = "DELETE FROM t1";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![table("t1", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_table_identifier() {
            let sql = "DELETE FROM catalog.schema.t1";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![qualified_table("catalog", "schema", "t1", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_alias() {
            let sql = "DELETE FROM t1 AS t1_alias";
            let expected = vec![Ok(CrudTables {
                delete_tables: vec![table("t1", Some("t1_alias"))],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax() {
            let sql = "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t1", None), table("t2", None), table("t3", None)],
                delete_tables: vec![table("t1", None), table("t2", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_alias() {
            let sql =
                "DELETE t1_alias, t2_alias FROM t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![
                    table("t1", Some("t1_alias")),
                    table("t2", Some("t2_alias")),
                    table("t3", None),
                ],
                delete_tables: vec![
                    table("t1", Some("t1_alias")),
                    table("t2", Some("t2_alias")),
                ],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_using() {
            let sql = "DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t1", None), table("t2", None), table("t3", None)],
                delete_tables: vec![table("t1", None), table("t2", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_syntax_with_using_with_alias() {
            let sql = "DELETE FROM t1_alias, t2_alias USING t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![
                    table("t1", Some("t1_alias")),
                    table("t2", Some("t2_alias")),
                    table("t3", None),
                ],
                delete_tables: vec![
                    table("t1", Some("t1_alias")),
                    table("t2", Some("t2_alias")),
                ],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    mod insert_statement {
        use super::*;

        #[test]
        fn test_insert_statement() {
            let sql = "INSERT INTO t1 (a) VALUES (1)";
            let expected = vec![Ok(CrudTables {
                create_tables: vec![table("t1", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_insert_select_statement() {
            let sql = "INSERT INTO t1 (a) SELECT a FROM t2 AS t2_alias INNER JOIN t3 USING (id)";
            let expected = vec![Ok(CrudTables {
                create_tables: vec![table("t1", None)],
                read_tables: vec![table("t2", Some("t2_alias")), table("t3", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    mod update_statement {
        use super::*;

        #[test]
        fn test_update_statement() {
            // Unlike its siblings, this test is pinned to the MySQL dialect only.
            let sql = "UPDATE t1 SET a=1";
            let result = CrudTableExtractor::extract(&MySqlDialect {}, sql).unwrap();
            assert_eq!(
                result,
                vec![Ok(CrudTables {
                    update_tables: vec![table("t1", None)],
                    ..Default::default()
                })]
            );
        }

        #[test]
        fn test_update_statement_with_alias() {
            let sql = "UPDATE t1 AS t1_alias INNER JOIN t2 ON t1_alias.a = t2.a SET t1_alias.b = t2.b WHERE t2.c = (SELECT c FROM t3)";
            let expected = vec![Ok(CrudTables {
                read_tables: vec![table("t3", None)],
                update_tables: vec![table("t1", Some("t1_alias")), table("t2", None)],
                ..Default::default()
            })];
            assert_crud_table_extraction(sql, expected, all_dialects());
        }
    }

    #[test]
    fn test_merge_statement() {
        let sql = "MERGE INTO t1 AS t1_alias USING t2 AS t2_alias ON t1_alias.a = t2_alias.a \
                         WHEN MATCHED AND t2_alias.b = 1 THEN DELETE \
                         WHEN MATCHED AND t2_alias.b = 2 THEN UPDATE SET t1_alias.b = t2_alias.b \
                         WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2_alias.a, t2_alias.b)";
        let expected = vec![Ok(CrudTables {
            create_tables: vec![table("t1", Some("t1_alias"))],
            read_tables: vec![table("t2", Some("t2_alias"))],
            update_tables: vec![table("t1", Some("t1_alias"))],
            delete_tables: vec![table("t1", Some("t1_alias"))],
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_create_table_statement() {
        let sql = "CREATE TABLE t1 (a INT)";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1", None)],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_alters_table_statement() {
        let sql = "ALTER TABLE t1 ADD COLUMN a INT";
        let expected = vec![Ok(CrudTables {
            read_tables: vec![table("t1", None)],
            ..Default::default()
        })];
        assert_crud_table_extraction(sql, expected, all_dialects());
    }
}