1use core::fmt;
6use std::ops::ControlFlow;
7
8use crate::error::Error;
9use crate::helper;
10use sqlparser::ast::{Ident, ObjectName, Statement, TableFactor, TableWithJoins, Visit, Visitor};
11use sqlparser::dialect::Dialect;
12use sqlparser::parser::Parser;
13
14pub fn extract_tables(
28 dialect: &dyn Dialect,
29 sql: &str,
30) -> Result<Vec<Result<Tables, Error>>, Error> {
31 TableExtractor::extract(dialect, sql)
32}
33
34#[derive(Clone, Debug, PartialEq, Eq, Hash)]
38pub struct TableReference {
39 pub catalog: Option<Ident>,
40 pub schema: Option<Ident>,
41 pub name: Ident,
42 pub alias: Option<Ident>,
43}
44
45impl TableReference {
46 pub fn has_alias(&self) -> bool {
47 self.alias.is_some()
48 }
49 pub fn has_qualifiers(&self) -> bool {
50 self.catalog.is_some() || self.schema.is_some()
51 }
52}
53
54impl fmt::Display for TableReference {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 let mut parts = Vec::new();
57 if let Some(catalog) = &self.catalog {
58 parts.push(catalog.to_string());
59 }
60 if let Some(schema) = &self.schema {
61 parts.push(schema.to_string());
62 }
63 parts.push(self.name.to_string());
64 let table = parts.join(".");
65 if let Some(alias) = &self.alias {
66 write!(f, "{} AS {}", table, alias)
67 } else {
68 write!(f, "{}", table)
69 }
70 }
71}
72
73impl TryFrom<&TableFactor> for TableReference {
74 type Error = Error;
75
76 fn try_from(table: &TableFactor) -> Result<Self, Self::Error> {
77 match table {
78 TableFactor::Table { name, alias, .. } => match name.0.len() {
79 0 => unreachable!("Parser should not allow empty identifiers"),
80 1 => Ok(TableReference {
81 catalog: None,
82 schema: None,
83 name: name.0[0].clone(),
84 alias: alias.as_ref().map(|a| a.name.clone()),
85 }),
86 2 => Ok(TableReference {
87 catalog: None,
88 schema: Some(name.0[0].clone()),
89 name: name.0[1].clone(),
90 alias: alias.as_ref().map(|a| a.name.clone()),
91 }),
92 3 => Ok(TableReference {
93 catalog: Some(name.0[0].clone()),
94 schema: Some(name.0[1].clone()),
95 name: name.0[2].clone(),
96 alias: alias.as_ref().map(|a| a.name.clone()),
97 }),
98 _ => Err(Error::AnalysisError(
99 "Too many identifiers provided".to_string(),
100 )),
101 },
102 _ => unreachable!("TableFactor::Table expected"),
103 }
104 }
105}
106
107impl TryFrom<&ObjectName> for TableReference {
108 type Error = Error;
109
110 fn try_from(obj_name: &ObjectName) -> Result<Self, Self::Error> {
111 match obj_name.0.len() {
112 0 => unreachable!("Parser should not allow empty identifiers"),
113 1 => Ok(TableReference {
114 catalog: None,
115 schema: None,
116 name: obj_name.0[0].clone(),
117 alias: None,
118 }),
119 2 => Ok(TableReference {
120 catalog: None,
121 schema: Some(obj_name.0[0].clone()),
122 name: obj_name.0[1].clone(),
123 alias: None,
124 }),
125 3 => Ok(TableReference {
126 catalog: Some(obj_name.0[0].clone()),
127 schema: Some(obj_name.0[1].clone()),
128 name: obj_name.0[2].clone(),
129 alias: None,
130 }),
131 _ => Err(Error::AnalysisError(
132 "Too many identifiers provided".to_string(),
133 )),
134 }
135 }
136}
137
138#[derive(Debug, PartialEq)]
140pub struct Tables(pub Vec<TableReference>);
141
142impl fmt::Display for Tables {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 let tables = self
145 .0
146 .iter()
147 .map(|t| t.to_string())
148 .collect::<Vec<String>>()
149 .join(", ");
150 write!(f, "{}", tables)
151 }
152}
153
154#[derive(Default, Debug)]
156pub struct TableExtractor {
157 all_tables: Vec<TableReference>,
159 original_tables: Vec<TableReference>,
161 relation_of_table: bool,
163}
164
165impl Visitor for TableExtractor {
166 type Break = Error;
167
168 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
169 if self.relation_of_table {
171 self.relation_of_table = false;
172 return ControlFlow::Continue(());
173 }
174 match TableReference::try_from(relation) {
175 Ok(table) => {
176 self.all_tables.push(table.clone());
177 self.original_tables.push(table)
178 }
179 Err(e) => return ControlFlow::Break(e),
180 }
181 ControlFlow::Continue(())
182 }
183
184 fn pre_visit_table_factor(&mut self, table_factor: &TableFactor) -> ControlFlow<Self::Break> {
185 if let TableFactor::Table { .. } = table_factor {
186 self.relation_of_table = true;
187 match TableReference::try_from(table_factor) {
188 Ok(table) => {
189 self.all_tables.push(table.clone());
190 self.original_tables.push(table)
191 }
192 Err(e) => return ControlFlow::Break(e),
193 }
194 }
195 ControlFlow::Continue(())
196 }
197
198 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
199 if let Statement::Delete { tables, .. } = statement {
200 for table in tables {
202 match TableReference::try_from(table) {
203 Ok(table) => self.all_tables.push(table),
204 Err(e) => return ControlFlow::Break(e),
205 }
206 }
207 }
208 ControlFlow::Continue(())
209 }
210}
211
212impl TableExtractor {
213 pub fn extract(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Result<Tables, Error>>, Error> {
215 let statements = Parser::parse_sql(dialect, sql)?;
216 let results = statements
217 .iter()
218 .map(Self::extract_from_statement)
219 .collect::<Vec<Result<Tables, Error>>>();
220 Ok(results)
221 }
222
223 pub fn extract_from_statement(statement: &Statement) -> Result<Tables, Error> {
224 let mut visitor = TableExtractor::default();
225 match statement.visit(&mut visitor) {
226 ControlFlow::Break(e) => Err(e),
227 ControlFlow::Continue(()) => Ok(Tables(helper::resolve_aliased_tables(
228 visitor.all_tables,
229 visitor.original_tables,
230 ))),
231 }
232 }
233
234 pub fn extract_from_table_node(table: &TableWithJoins) -> Result<Tables, Error> {
237 let mut visitor = TableExtractor::default();
238 match table.visit(&mut visitor) {
239 ControlFlow::Break(e) => Err(e),
240 ControlFlow::Continue(()) => Ok(Tables(helper::resolve_aliased_tables(
241 visitor.all_tables,
242 visitor.original_tables,
243 ))),
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::all_dialects;

    /// An unqualified, unaliased table reference.
    fn table(name: &str) -> TableReference {
        TableReference {
            catalog: None,
            schema: None,
            name: name.into(),
            alias: None,
        }
    }

    /// An unqualified table reference carrying an alias.
    fn aliased(name: &str, alias: &str) -> TableReference {
        TableReference {
            alias: Some(alias.into()),
            ..table(name)
        }
    }

    /// The successful extraction result for one statement.
    fn extracted(references: Vec<TableReference>) -> Result<Tables, Error> {
        Ok(Tables(references))
    }

    /// Runs the extractor on `sql` for every dialect and compares against `expected`.
    fn assert_table_extraction(
        sql: &str,
        expected: Vec<Result<Tables, Error>>,
        dialects: Vec<Box<dyn Dialect>>,
    ) {
        for dialect in dialects {
            let actual = TableExtractor::extract(dialect.as_ref(), sql).unwrap();
            assert_eq!(actual, expected, "Failed for dialect: {dialect:?}")
        }
    }

    #[test]
    fn test_single_statement() {
        let sql = "SELECT a FROM t1";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT a FROM t1; SELECT b FROM t2";
        let expected = vec![extracted(vec![table("t1")]), extracted(vec![table("t2")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_alias() {
        let sql = "SELECT a FROM t1 AS t1_alias";
        let expected = vec![extracted(vec![aliased("t1", "t1_alias")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_schema_identifier() {
        let sql = "SELECT a FROM schema.table; INSERT INTO schema.table (a) VALUES (1)";
        let schema_table = || TableReference {
            schema: Some("schema".into()),
            ..table("table")
        };
        let expected = vec![
            extracted(vec![schema_table()]),
            extracted(vec![schema_table()]),
        ];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_full_identifier() {
        let sql =
            "SELECT a FROM catalog.schema.table; INSERT INTO catalog.schema.table (a) VALUES (1)";
        let full_table = || TableReference {
            catalog: Some("catalog".into()),
            schema: Some("schema".into()),
            ..table("table")
        };
        let expected = vec![extracted(vec![full_table()]), extracted(vec![full_table()])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_with_table_identifier_and_alias() {
        let sql = "SELECT a FROM catalog.schema.table AS table_alias";
        let expected = vec![extracted(vec![TableReference {
            catalog: Some("catalog".into()),
            schema: Some("schema".into()),
            ..aliased("table", "table_alias")
        }])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_where_same_tables_appear_multiple_times() {
        let sql = "SELECT a FROM t1 INNER JOIN t2 ON t1.id = t2.id WHERE b = ( SELECT c FROM t3 INNER JOIN t1 ON t3.id = t1.id )";
        let expected = vec![extracted(vec![
            table("t1"),
            table("t2"),
            table("t3"),
            table("t1"),
        ])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_statement_error_with_too_many_identifiers() {
        let sql = "SELECT a FROM catalog.schema.table.extra";
        let expected = vec![Err(Error::AnalysisError(
            "Too many identifiers provided".to_string(),
        ))];
        assert_table_extraction(sql, expected, all_dialects());
    }

    mod delete_statement {
        use super::*;

        #[test]
        fn test_delete_statement() {
            let sql = "DELETE t1 FROM t1";
            let expected = vec![extracted(vec![table("t1"), table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_statement_with_aliases() {
            let sql = "DELETE t1_alias FROM t1 AS t1_alias JOIN t2 AS t2_alias ON t1_alias.a = t2_alias.a WHERE t2_alias.b = 1";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_multiple_tables_with_join() {
            let sql =
                "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.a = t2.a AND t2.a = t3.a";
            let expected = vec![extracted(vec![
                table("t1"),
                table("t2"),
                table("t1"),
                table("t2"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_from_statement() {
            let sql = "DELETE FROM t1";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_delete_from_statement_with_alias() {
            let sql = "DELETE FROM t1_alias, t2_alias USING t1 AS t1_alias INNER JOIN t2 AS t2_alias INNER JOIN t3";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
                aliased("t1", "t1_alias"),
                aliased("t2", "t2_alias"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    mod insert_statement {
        use super::*;

        #[test]
        fn test_insert_statement() {
            let sql = "INSERT INTO t1 (a, b) VALUES (1, 2)";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_insert_select_statement() {
            let sql = "INSERT INTO t1 SELECT * FROM t2";
            let expected = vec![extracted(vec![table("t1"), table("t2")])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    mod update_statement {
        use super::*;

        #[test]
        fn test_update_statement() {
            let sql = "UPDATE t1 SET a = 1";
            let expected = vec![extracted(vec![table("t1")])];
            assert_table_extraction(sql, expected, all_dialects());
        }

        #[test]
        fn test_update_statement_with_alias() {
            let sql = "UPDATE t1 AS t1_alias INNER JOIN t2 ON t1_alias.a = t2.a SET t1_alias.b = t2.b WHERE t2.c = (SELECT c FROM t3)";
            let expected = vec![extracted(vec![
                aliased("t1", "t1_alias"),
                table("t2"),
                table("t3"),
            ])];
            assert_table_extraction(sql, expected, all_dialects());
        }
    }

    #[test]
    fn test_merge_statement() {
        let sql = "MERGE INTO t1 USING t2 ON t1.a = t2.a \
            WHEN MATCHED THEN UPDATE SET t1.b = t2.b \
            WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2.a, t2.b)";
        let expected = vec![extracted(vec![table("t1"), table("t2")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_merge_statement_with_alias() {
        let sql = "MERGE INTO t1 AS t1_alias USING (SELECT a, b FROM t2) AS t2_alias(a, b) ON t1_alias.a = t2_alias.a \
            WHEN MATCHED THEN UPDATE SET t1_alias.b = t2_alias.b \
            WHEN NOT MATCHED THEN INSERT (a, b) VALUES (t2_alias.a, t2_alias.b)";
        let expected = vec![extracted(vec![aliased("t1", "t1_alias"), table("t2")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_create_table_statement() {
        let sql = "CREATE TABLE t1 (a INT)";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }

    #[test]
    fn test_alters_table_statement() {
        let sql = "ALTER TABLE t1 ADD COLUMN a INT";
        let expected = vec![extracted(vec![table("t1")])];
        assert_table_extraction(sql, expected, all_dialects());
    }
}