sql_insight/extractor/
table_extractor.rs

//! An extractor that extracts tables from SQL queries.
//!
//! See [`extract_tables`](crate::extract_tables()) as the entry point for extracting tables from SQL.
4
5use core::fmt;
6use std::ops::ControlFlow;
7
8use crate::error::Error;
9use crate::helper;
10use sqlparser::ast::{Ident, ObjectName, Statement, TableFactor, TableWithJoins, Visit, Visitor};
11use sqlparser::dialect::Dialect;
12use sqlparser::parser::Parser;
13
14/// Convenience function to extract tables from SQL.
15///
16/// ## Example
17///
18/// ```rust
19/// use sql_insight::sqlparser::dialect::GenericDialect;
20///
21/// let dialect = GenericDialect {};
22/// let sql = "SELECT a FROM t1 INNER JOIN t2 ON t1.id = t2.id";
23/// let result = sql_insight::extract_tables(&dialect, sql).unwrap();
24/// println!("{:#?}", result);
25/// assert_eq!(result[0].as_ref().unwrap().to_string(), "t1, t2");
26/// ```
27pub fn extract_tables(
28    dialect: &dyn Dialect,
29    sql: &str,
30) -> Result<Vec<Result<Tables, Error>>, Error> {
31    TableExtractor::extract(dialect, sql)
32}
33
34/// [`TableReference`] represents a qualified table with alias.
35/// In this crate, this is the canonical representation of a table.
36/// Tables found during analyzing an AST are stored as `TableReference`.
37#[derive(Clone, Debug, PartialEq, Eq, Hash)]
38pub struct TableReference {
39    pub catalog: Option<Ident>,
40    pub schema: Option<Ident>,
41    pub name: Ident,
42    pub alias: Option<Ident>,
43}
44
45impl TableReference {
46    pub fn has_alias(&self) -> bool {
47        self.alias.is_some()
48    }
49    pub fn has_qualifiers(&self) -> bool {
50        self.catalog.is_some() || self.schema.is_some()
51    }
52}
53
54impl fmt::Display for TableReference {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        let mut parts = Vec::new();
57        if let Some(catalog) = &self.catalog {
58            parts.push(catalog.to_string());
59        }
60        if let Some(schema) = &self.schema {
61            parts.push(schema.to_string());
62        }
63        parts.push(self.name.to_string());
64        let table = parts.join(".");
65        if let Some(alias) = &self.alias {
66            write!(f, "{} AS {}", table, alias)
67        } else {
68            write!(f, "{}", table)
69        }
70    }
71}
72
73impl TryFrom<&TableFactor> for TableReference {
74    type Error = Error;
75
76    fn try_from(table: &TableFactor) -> Result<Self, Self::Error> {
77        match table {
78            TableFactor::Table { name, alias, .. } => match name.0.len() {
79                0 => unreachable!("Parser should not allow empty identifiers"),
80                1 => Ok(TableReference {
81                    catalog: None,
82                    schema: None,
83                    name: name.0[0].clone(),
84                    alias: alias.as_ref().map(|a| a.name.clone()),
85                }),
86                2 => Ok(TableReference {
87                    catalog: None,
88                    schema: Some(name.0[0].clone()),
89                    name: name.0[1].clone(),
90                    alias: alias.as_ref().map(|a| a.name.clone()),
91                }),
92                3 => Ok(TableReference {
93                    catalog: Some(name.0[0].clone()),
94                    schema: Some(name.0[1].clone()),
95                    name: name.0[2].clone(),
96                    alias: alias.as_ref().map(|a| a.name.clone()),
97                }),
98                _ => Err(Error::AnalysisError(
99                    "Too many identifiers provided".to_string(),
100                )),
101            },
102            _ => unreachable!("TableFactor::Table expected"),
103        }
104    }
105}
106
107impl TryFrom<&ObjectName> for TableReference {
108    type Error = Error;
109
110    fn try_from(obj_name: &ObjectName) -> Result<Self, Self::Error> {
111        match obj_name.0.len() {
112            0 => unreachable!("Parser should not allow empty identifiers"),
113            1 => Ok(TableReference {
114                catalog: None,
115                schema: None,
116                name: obj_name.0[0].clone(),
117                alias: None,
118            }),
119            2 => Ok(TableReference {
120                catalog: None,
121                schema: Some(obj_name.0[0].clone()),
122                name: obj_name.0[1].clone(),
123                alias: None,
124            }),
125            3 => Ok(TableReference {
126                catalog: Some(obj_name.0[0].clone()),
127                schema: Some(obj_name.0[1].clone()),
128                name: obj_name.0[2].clone(),
129                alias: None,
130            }),
131            _ => Err(Error::AnalysisError(
132                "Too many identifiers provided".to_string(),
133            )),
134        }
135    }
136}
137
138/// [`Tables`] represents a list of [`TableReference`] that found in SQL.
139#[derive(Debug, PartialEq)]
140pub struct Tables(pub Vec<TableReference>);
141
142impl fmt::Display for Tables {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        let tables = self
145            .0
146            .iter()
147            .map(|t| t.to_string())
148            .collect::<Vec<String>>()
149            .join(", ");
150        write!(f, "{}", tables)
151    }
152}
153
154/// A visitor to extract tables from SQL.
155#[derive(Default, Debug)]
156pub struct TableExtractor {
157    // All tables found in the SQL including aliases, must be resolved to original tables.
158    all_tables: Vec<TableReference>,
159    // Original tables found in the SQL, used to resolve aliases.
160    original_tables: Vec<TableReference>,
161    // Flag to indicate if the current relation is part of a `TableFactor::Table`
162    relation_of_table: bool,
163}
164
165impl Visitor for TableExtractor {
166    type Break = Error;
167
168    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
169        // Skip if relation is part of a TableFactor::Table
170        if self.relation_of_table {
171            self.relation_of_table = false;
172            return ControlFlow::Continue(());
173        }
174        match TableReference::try_from(relation) {
175            Ok(table) => {
176                self.all_tables.push(table.clone());
177                self.original_tables.push(table)
178            }
179            Err(e) => return ControlFlow::Break(e),
180        }
181        ControlFlow::Continue(())
182    }
183
184    fn pre_visit_table_factor(&mut self, table_factor: &TableFactor) -> ControlFlow<Self::Break> {
185        if let TableFactor::Table { .. } = table_factor {
186            self.relation_of_table = true;
187            match TableReference::try_from(table_factor) {
188                Ok(table) => {
189                    self.all_tables.push(table.clone());
190                    self.original_tables.push(table)
191                }
192                Err(e) => return ControlFlow::Break(e),
193            }
194        }
195        ControlFlow::Continue(())
196    }
197
198    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
199        if let Statement::Delete { tables, .. } = statement {
200            // tables of delete statement are not visited by `pre_visit_table_factor` nor `pre_visit_relation`.
201            for table in tables {
202                match TableReference::try_from(table) {
203                    Ok(table) => self.all_tables.push(table),
204                    Err(e) => return ControlFlow::Break(e),
205                }
206            }
207        }
208        ControlFlow::Continue(())
209    }
210}
211
212impl TableExtractor {
213    /// Extract tables from SQL.
214    pub fn extract(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Result<Tables, Error>>, Error> {
215        let statements = Parser::parse_sql(dialect, sql)?;
216        let results = statements
217            .iter()
218            .map(Self::extract_from_statement)
219            .collect::<Vec<Result<Tables, Error>>>();
220        Ok(results)
221    }
222
223    pub fn extract_from_statement(statement: &Statement) -> Result<Tables, Error> {
224        let mut visitor = TableExtractor::default();
225        match statement.visit(&mut visitor) {
226            ControlFlow::Break(e) => Err(e),
227            ControlFlow::Continue(()) => Ok(Tables(helper::resolve_aliased_tables(
228                visitor.all_tables,
229                visitor.original_tables,
230            ))),
231        }
232    }
233
234    // `Visit` trait object cannot be used since method `visit` has generic type parameters.
235    // Concrete type `TableWithJoins` is used instead.
236    pub fn extract_from_table_node(table: &TableWithJoins) -> Result<Tables, Error> {
237        let mut visitor = TableExtractor::default();
238        match table.visit(&mut visitor) {
239            ControlFlow::Break(e) => Err(e),
240            ControlFlow::Continue(()) => Ok(Tables(helper::resolve_aliased_tables(
241                visitor.all_tables,
242                visitor.original_tables,
243            ))),
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::all_dialects;

    /// Builds a [`TableReference`] from plain string parts.
    fn reference(
        catalog: Option<&str>,
        schema: Option<&str>,
        name: &str,
        alias: Option<&str>,
    ) -> TableReference {
        TableReference {
            catalog: catalog.map(Into::into),
            schema: schema.map(Into::into),
            name: name.into(),
            alias: alias.map(Into::into),
        }
    }

    /// An unqualified table without alias, e.g. `t1`.
    fn table(name: &str) -> TableReference {
        reference(None, None, name, None)
    }

    /// An unqualified table with alias, e.g. `t1 AS t1_alias`.
    fn aliased(name: &str, alias: &str) -> TableReference {
        reference(None, None, name, Some(alias))
    }

    /// The successful extraction result for one statement.
    fn extracted(tables: Vec<TableReference>) -> Result<Tables, Error> {
        Ok(Tables(tables))
    }

    /// Runs the extraction under every given dialect and compares it with `expected`.
    fn assert_table_extraction(
        sql: &str,
        expected: Vec<Result<Tables, Error>>,
        dialects: Vec<Box<dyn Dialect>>,
    ) {
        for dialect in dialects {
            let actual = TableExtractor::extract(dialect.as_ref(), sql).unwrap();
            assert_eq!(actual, expected, "Failed for dialect: {dialect:?}")
        }
    }

    #[test]
    fn test_single_statement() {
        let sql = "SELECT a FROM t1";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT a FROM t1; SELECT b FROM t2";
        let expected = vec![
            extracted(vec![table("t1")]),
            extracted(vec![table("t2")]),
        ];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_alias() {
        let sql = "SELECT a FROM t1 AS t1_alias";
        let expected = vec![extracted(vec![aliased("t1", "t1_alias")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_schema_identifier() {
        let sql = "SELECT a FROM schema.table; INSERT INTO schema.table (a) VALUES (1)";
        let expected = vec![
            extracted(vec![reference(None, Some("schema"), "table", None)]),
            extracted(vec![reference(None, Some("schema"), "table", None)]),
        ];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_full_identifier() {
        let sql =
            "SELECT a FROM catalog.schema.table; INSERT INTO catalog.schema.table (a) VALUES (1)";
        let expected = vec![
            extracted(vec![reference(
                Some("catalog"),
                Some("schema"),
                "table",
                None,
            )]),
            extracted(vec![reference(
                Some("catalog"),
                Some("schema"),
                "table",
                None,
            )]),
        ];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier_and_alias() {
        let sql = "SELECT a FROM catalog.schema.table AS table_alias";
        let expected = vec![extracted(vec![reference(
            Some("catalog"),
            Some("schema"),
            "table",
            Some("table_alias"),
        )])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_where_same_tables_appear_multiple_times() {
        let sql = "SELECT a FROM t1 INNER JOIN t2 ON t1.id = t2.id WHERE b = ( SELECT c FROM t3 INNER JOIN t1 ON t3.id = t1.id )";
        let expected = vec![extracted(vec![
            table("t1"),
            table("t2"),
            table("t3"),
            table("t1"),
        ])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_error_with_too_many_identifiers() {
        let sql = "SELECT a FROM catalog.schema.table.extra";
        let expected = vec![Err(Error::AnalysisError(
            "Too many identifiers provided".to_string(),
        ))];
        assert_table_extraction(sql, expected, all_dialects());
    }

    mod delete_statement {
        use super::*;

        #[test]
        fn test_delete_statement() {
            let sql = "DELETE t1 FROM t1";
            let expected = vec![extracted(vec![table("t1"), table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_aliases() {
            let sql = "DELETE t1_alias FROM t1 AS t1_alias JOIN t2 AS t2_alias ON t1_alias.a = t2_alias.a WHERE t2_alias.b = 1";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_with_join() {
            let sql =
                "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.a = t2.a AND t2.a = t3.a";
            let expected = vec![extracted(vec![
                table("t1"),
                table("t2"),
                table("t1"),
                table("t2"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_from_statement() {
            let sql = "DELETE FROM t1";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_from_statement_with_alias() {
            let sql = "DELETE FROM t1_alias, t2_alias USING t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    mod insert_statement {
        use super::*;

        #[test]
        fn test_insert_statement() {
            let sql = "INSERT INTO t1 (a, b) VALUES (1, 2)";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_insert_select_statement() {
            let sql = "INSERT INTO t1 SELECT * FROM t2";
            let expected = vec![extracted(vec![table("t1"), table("t2")])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    mod update_statement {
        use super::*;

        #[test]
        fn test_update_statement() {
            let sql = "UPDATE t1 SET a = 1";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_update_statement_with_alias() {
            let sql = "UPDATE t1 AS t1_alias INNER JOIN t2 ON t1_alias.a = t2.a SET t1_alias.b = t2.b WHERE t2.c = (SELECT c FROM t3)";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                table("t2"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    #[test]
    fn test_merge_statement() {
        let sql = "MERGE INTO t1 USING t2 ON t1.a = t2.a \
                         WHEN MATCHED THEN UPDATE SET t1.b = t2.b \
                         WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2.a, t2.b)";
        let expected = vec![extracted(vec![table("t1"), table("t2")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_merge_statement_with_alias() {
        let sql = "MERGE INTO t1 AS t1_alias USING (SELECT a, b FROM t2) AS t2_alias(a, b) ON t1_alias.a = t2_alias.a \
                         WHEN MATCHED THEN UPDATE SET t1_alias.b = t2_alias.b \
                         WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2_alias.a, t2_alias.b)";
        let expected = vec![extracted(vec![
            aliased("t1", "t1_alias"),
            table("t2"),
        ])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_create_table_statement() {
        let sql = "CREATE TABLE t1 (a INT)";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_alters_table_statement() {
        let sql = "ALTER TABLE t1 ADD COLUMN a INT";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }
}