1use std::collections::HashMap;
11
12use crate::ast::*;
13use crate::dialects::Dialect;
14use crate::schema::{Schema, normalize_identifier};
15
16pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
21 let dialect = schema.dialect();
22 match statement {
23 Statement::Select(sel) => {
24 let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
25 Statement::Select(qualified)
26 }
27 Statement::SetOperation(mut set_op) => {
28 set_op.left = Box::new(qualify_columns(*set_op.left, schema));
29 set_op.right = Box::new(qualify_columns(*set_op.right, schema));
30 Statement::SetOperation(set_op)
31 }
32 other => other,
33 }
34}
35
/// Column names exposed by a single source of a `FROM`/`JOIN` clause.
///
/// Instances are stored in a map keyed by the normalized alias (or table name)
/// of the source they describe.
#[derive(Debug, Clone)]
struct SourceColumns {
    /// Names of the columns the source provides. The list is empty for sources
    /// registered without known columns (pivot/unpivot aliases, unnests and
    /// table functions).
    columns: Vec<String>,
}
42
43fn resolve_source_columns<S: Schema>(
45 sel: &SelectStatement,
46 schema: &S,
47 dialect: Dialect,
48 cte_columns: &HashMap<String, Vec<String>>,
49) -> HashMap<String, SourceColumns> {
50 let mut source_map: HashMap<String, SourceColumns> = HashMap::new();
51
52 if let Some(from) = &sel.from {
54 collect_source_columns(&from.source, schema, dialect, cte_columns, &mut source_map);
55 }
56
57 for join in &sel.joins {
59 collect_source_columns(&join.table, schema, dialect, cte_columns, &mut source_map);
60 }
61
62 source_map
63}
64
65fn collect_source_columns<S: Schema>(
67 source: &TableSource,
68 schema: &S,
69 dialect: Dialect,
70 cte_columns: &HashMap<String, Vec<String>>,
71 source_map: &mut HashMap<String, SourceColumns>,
72) {
73 match source {
74 TableSource::Table(table_ref) => {
75 let key = table_ref
76 .alias
77 .as_deref()
78 .unwrap_or(&table_ref.name)
79 .to_string();
80 let norm_key = normalize_identifier(&key, dialect);
81
82 let norm_name = normalize_identifier(&table_ref.name, dialect);
84 if let Some(cols) = cte_columns.get(&norm_name) {
85 source_map.insert(
86 norm_key,
87 SourceColumns {
88 columns: cols.clone(),
89 },
90 );
91 return;
92 }
93
94 let path = build_table_path(table_ref, dialect);
96 let path_refs: Vec<&str> = path.iter().map(|s| s.as_str()).collect();
97
98 if let Ok(cols) = schema.column_names(&path_refs) {
99 source_map.insert(norm_key, SourceColumns { columns: cols });
100 }
101 }
102 TableSource::Subquery { query, alias, .. } => {
103 if let Some(alias) = alias {
104 let norm_alias = normalize_identifier(alias, dialect);
105 let cols = extract_output_columns(query, schema, dialect, cte_columns);
106 source_map.insert(norm_alias, SourceColumns { columns: cols });
107 }
108 }
109 TableSource::Lateral { source: inner } => {
110 collect_source_columns(inner, schema, dialect, cte_columns, source_map);
111 }
112 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
113 collect_source_columns(source, schema, dialect, cte_columns, source_map);
114 if let Some(alias) = alias {
115 let norm_alias = normalize_identifier(alias, dialect);
116 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
117 }
118 }
119 TableSource::Unnest { alias, .. } => {
120 if let Some(alias) = alias {
121 let norm_alias = normalize_identifier(alias, dialect);
122 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
124 }
125 }
126 TableSource::TableFunction { alias, .. } => {
127 if let Some(alias) = alias {
128 let norm_alias = normalize_identifier(alias, dialect);
129 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
130 }
131 }
132 }
133}
134
135fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
137 let mut path = Vec::new();
138 if let Some(cat) = &table_ref.catalog {
139 path.push(normalize_identifier(cat, dialect));
140 }
141 if let Some(sch) = &table_ref.schema {
142 path.push(normalize_identifier(sch, dialect));
143 }
144 path.push(normalize_identifier(&table_ref.name, dialect));
145 path
146}
147
148fn extract_output_columns<S: Schema>(
150 stmt: &Statement,
151 schema: &S,
152 dialect: Dialect,
153 cte_columns: &HashMap<String, Vec<String>>,
154) -> Vec<String> {
155 match stmt {
156 Statement::Select(sel) => {
157 let inner_sources = resolve_source_columns(sel, schema, dialect, cte_columns);
158 let mut cols = Vec::new();
159 for item in &sel.columns {
160 match item {
161 SelectItem::Wildcard => {
162 for_each_source_ordered(sel, dialect, &inner_sources, |sc| {
164 cols.extend(sc.columns.iter().cloned());
165 });
166 }
167 SelectItem::QualifiedWildcard { table } => {
168 let norm_table = normalize_identifier(table, dialect);
169 if let Some(sc) = inner_sources.get(&norm_table) {
170 cols.extend(sc.columns.iter().cloned());
171 }
172 }
173 SelectItem::Expr { alias, expr, .. } => {
174 if let Some(alias) = alias {
175 cols.push(alias.clone());
176 } else {
177 cols.push(expr_output_name(expr));
178 }
179 }
180 }
181 }
182 cols
183 }
184 Statement::SetOperation(set_op) => {
185 extract_output_columns(&set_op.left, schema, dialect, cte_columns)
187 }
188 _ => vec![],
189 }
190}
191
192fn expr_output_name(expr: &Expr) -> String {
194 match expr {
195 Expr::Column { name, .. } => name.clone(),
196 Expr::Function { name, .. } => name.clone(),
197 Expr::TypedFunction { .. } => "_col".to_string(),
198 _ => "_col".to_string(),
199 }
200}
201
202fn for_each_source_ordered<F>(
204 sel: &SelectStatement,
205 dialect: Dialect,
206 source_map: &HashMap<String, SourceColumns>,
207 mut callback: F,
208) where
209 F: FnMut(&SourceColumns),
210{
211 if let Some(from) = &sel.from {
213 let key = source_key_for(&from.source, dialect);
214 if let Some(sc) = source_map.get(&key) {
215 callback(sc);
216 }
217 }
218 for join in &sel.joins {
220 let key = source_key_for(&join.table, dialect);
221 if let Some(sc) = source_map.get(&key) {
222 callback(sc);
223 }
224 }
225}
226
227fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
229 match source {
230 TableSource::Table(tr) => {
231 let name = tr.alias.as_deref().unwrap_or(&tr.name);
232 normalize_identifier(name, dialect)
233 }
234 TableSource::Subquery { alias, .. } => alias
235 .as_deref()
236 .map(|a| normalize_identifier(a, dialect))
237 .unwrap_or_default(),
238 TableSource::Lateral { source } => source_key_for(source, dialect),
239 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
240 if let Some(a) = alias {
241 normalize_identifier(a, dialect)
242 } else {
243 source_key_for(source, dialect)
244 }
245 }
246 TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => alias
247 .as_deref()
248 .map(|a| normalize_identifier(a, dialect))
249 .unwrap_or_default(),
250 }
251}
252
253fn qualify_select<S: Schema>(
255 mut sel: SelectStatement,
256 schema: &S,
257 dialect: Dialect,
258 outer_cte_columns: &HashMap<String, Vec<String>>,
259) -> SelectStatement {
260 let mut cte_columns = outer_cte_columns.clone();
262 for cte in &sel.ctes {
263 let cols = if !cte.columns.is_empty() {
264 cte.columns.clone()
266 } else {
267 extract_output_columns(&cte.query, schema, dialect, &cte_columns)
268 };
269 let norm_name = normalize_identifier(&cte.name, dialect);
270 cte_columns.insert(norm_name, cols);
271 }
272
273 sel.ctes = sel
275 .ctes
276 .into_iter()
277 .map(|mut cte| {
278 cte.query = Box::new(qualify_columns(*cte.query, schema));
279 cte
280 })
281 .collect();
282
283 if let Some(ref mut from) = sel.from {
285 qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
286 }
287 for join in &mut sel.joins {
288 qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
289 }
290
291 let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
293
294 let mut new_columns = Vec::new();
296 let old_columns = std::mem::take(&mut sel.columns);
297 for item in old_columns {
298 match item {
299 SelectItem::Wildcard => {
300 for_each_source_ordered(&sel, dialect, &source_map, |sc| {
302 for col_name in &sc.columns {
303 new_columns.push(SelectItem::Expr {
304 expr: Expr::Column {
305 table: None,
306 name: col_name.clone(),
307 quote_style: QuoteStyle::None,
308 table_quote_style: QuoteStyle::None,
309 },
310 alias: None,
311 alias_quote_style: QuoteStyle::None,
312 });
313 }
314 });
315 }
316 SelectItem::QualifiedWildcard { table } => {
317 let norm_table = normalize_identifier(&table, dialect);
318 if let Some(sc) = source_map.get(&norm_table) {
319 for col_name in &sc.columns {
320 new_columns.push(SelectItem::Expr {
321 expr: Expr::Column {
322 table: Some(table.clone()),
323 name: col_name.clone(),
324 quote_style: QuoteStyle::None,
325 table_quote_style: QuoteStyle::None,
326 },
327 alias: None,
328 alias_quote_style: QuoteStyle::None,
329 });
330 }
331 } else {
332 new_columns.push(SelectItem::QualifiedWildcard { table });
334 }
335 }
336 SelectItem::Expr {
337 expr,
338 alias,
339 alias_quote_style,
340 ..
341 } => {
342 let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
343 new_columns.push(SelectItem::Expr {
344 expr: qualified_expr,
345 alias,
346 alias_quote_style,
347 });
348 }
349 }
350 }
351 sel.columns = new_columns;
352
353 if let Some(wh) = sel.where_clause {
355 sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
356 }
357 sel.group_by = sel
358 .group_by
359 .into_iter()
360 .map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
361 .collect();
362 if let Some(having) = sel.having {
363 sel.having = Some(qualify_expr(
364 having,
365 &source_map,
366 schema,
367 dialect,
368 &cte_columns,
369 ));
370 }
371 sel.order_by = sel
372 .order_by
373 .into_iter()
374 .map(|mut item| {
375 item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
376 item
377 })
378 .collect();
379 if let Some(qualify) = sel.qualify {
380 sel.qualify = Some(qualify_expr(
381 qualify,
382 &source_map,
383 schema,
384 dialect,
385 &cte_columns,
386 ));
387 }
388
389 for join in &mut sel.joins {
391 if let Some(on) = join.on.take() {
392 join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
393 }
394 }
395
396 sel
397}
398
399fn qualify_table_source<S: Schema>(
401 source: &mut TableSource,
402 schema: &S,
403 dialect: Dialect,
404 cte_columns: &HashMap<String, Vec<String>>,
405) {
406 match source {
407 TableSource::Subquery { query, .. } => {
408 *query = Box::new(qualify_columns_inner(
409 *query.clone(),
410 schema,
411 dialect,
412 cte_columns,
413 ));
414 }
415 TableSource::Lateral { source: inner } => {
416 qualify_table_source(inner, schema, dialect, cte_columns);
417 }
418 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
419 qualify_table_source(source, schema, dialect, cte_columns);
420 }
421 _ => {}
422 }
423}
424
425fn qualify_columns_inner<S: Schema>(
427 statement: Statement,
428 schema: &S,
429 dialect: Dialect,
430 cte_columns: &HashMap<String, Vec<String>>,
431) -> Statement {
432 match statement {
433 Statement::Select(sel) => {
434 Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
435 }
436 Statement::SetOperation(mut set_op) => {
437 set_op.left = Box::new(qualify_columns_inner(
438 *set_op.left,
439 schema,
440 dialect,
441 cte_columns,
442 ));
443 set_op.right = Box::new(qualify_columns_inner(
444 *set_op.right,
445 schema,
446 dialect,
447 cte_columns,
448 ));
449 Statement::SetOperation(set_op)
450 }
451 other => other,
452 }
453}
454
455fn qualify_expr<S: Schema>(
458 expr: Expr,
459 source_map: &HashMap<String, SourceColumns>,
460 schema: &S,
461 dialect: Dialect,
462 cte_columns: &HashMap<String, Vec<String>>,
463) -> Expr {
464 expr.transform(&|e| match e {
465 Expr::Column {
466 table: None,
467 name,
468 quote_style,
469 table_quote_style,
470 } => {
471 let norm_name = normalize_identifier(&name, dialect);
472 let resolved_source = resolve_column(&norm_name, source_map);
474 if let Some(source_name) = resolved_source {
475 Expr::Column {
476 table: Some(source_name),
477 name,
478 quote_style,
479 table_quote_style,
480 }
481 } else {
482 Expr::Column {
485 table: None,
486 name,
487 quote_style,
488 table_quote_style,
489 }
490 }
491 }
492 Expr::InSubquery {
494 expr,
495 subquery,
496 negated,
497 } => Expr::InSubquery {
498 expr,
499 subquery: Box::new(qualify_columns_inner(
500 *subquery,
501 schema,
502 dialect,
503 cte_columns,
504 )),
505 negated,
506 },
507 Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
508 *stmt,
509 schema,
510 dialect,
511 cte_columns,
512 ))),
513 Expr::Exists { subquery, negated } => Expr::Exists {
514 subquery: Box::new(qualify_columns_inner(
515 *subquery,
516 schema,
517 dialect,
518 cte_columns,
519 )),
520 negated,
521 },
522 other => other,
523 })
524}
525
526fn resolve_column(
530 norm_col_name: &str,
531 source_map: &HashMap<String, SourceColumns>,
532) -> Option<String> {
533 let mut matches: Vec<&str> = Vec::new();
534 for (source_name, sc) in source_map {
535 if sc
536 .columns
537 .iter()
538 .any(|c| c.eq_ignore_ascii_case(norm_col_name))
539 {
540 matches.push(source_name);
541 }
542 }
543 if matches.len() == 1 {
544 Some(matches[0].to_string())
545 } else {
546 None
547 }
548}
549
// Unit tests for column qualification. Each test parses a SQL string with the
// ANSI dialect, runs `qualify_columns` against a small in-memory schema and
// compares the regenerated SQL text.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;

    // Builds the shared test schema with three tables:
    // users(id, name, email), orders(id, user_id, amount, status) and
    // products(id, name, price).
    fn make_schema() -> MappingSchema {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("email".to_string(), DataType::Text),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("user_id".to_string(), DataType::Int),
                    (
                        "amount".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                    ("status".to_string(), DataType::Varchar(Some(50))),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["products"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    (
                        "price".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                ],
            )
            .unwrap();
        schema
    }

    // Parses `sql`, qualifies it against `schema` and renders it back to SQL.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let stmt = parse(sql, Dialect::Ansi).unwrap();
        let qualified = qualify_columns(stmt, schema);
        generate(&qualified, Dialect::Ansi)
    }

    // `*` expands to the table's columns, without a table qualifier.
    #[test]
    fn test_expand_star() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT * FROM users", &schema),
            "SELECT id, name, email FROM users"
        );
    }

    // `users.*` expands to columns qualified with the name used in the query.
    #[test]
    fn test_expand_qualified_wildcard() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT users.* FROM users", &schema),
            "SELECT users.id, users.name, users.email FROM users"
        );
    }

    // An alias on the table does not prevent `*` from being expanded.
    #[test]
    fn test_expand_star_with_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT * FROM users AS u", &schema),
            "SELECT id, name, email FROM users AS u"
        );
    }

    // `u.*` expands to columns qualified with the alias.
    #[test]
    fn test_expand_qualified_wildcard_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT u.* FROM users AS u", &schema),
            "SELECT u.id, u.name, u.email FROM users AS u"
        );
    }

    // With a single source every known column is attributed to it.
    #[test]
    fn test_qualify_unqualified_single_table() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT id, name FROM users", &schema),
            "SELECT users.id, users.name FROM users"
        );
    }

    // The alias, not the table name, is used as the qualifier.
    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT id, name FROM users AS u", &schema),
            "SELECT u.id, u.name FROM users AS u"
        );
    }

    // Columns that already carry a qualifier are left as written.
    #[test]
    fn test_qualify_already_qualified() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT users.id, users.name FROM users", &schema),
            "SELECT users.id, users.name FROM users"
        );
    }

    // `name` only exists in `users` and `amount` only in `orders`, so both
    // resolve unambiguously across the join.
    #[test]
    fn test_qualify_join_unambiguous() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
                &schema
            ),
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }

    // `id` exists in both `users` and `orders`; an ambiguous column is left
    // unqualified rather than guessed.
    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        let schema = make_schema();
        let result = qualify(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }

    // Columns in the WHERE clause are qualified as well.
    #[test]
    fn test_qualify_where_clause() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name FROM users WHERE email = 'test@test.com'",
                &schema
            ),
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'"
        );
    }

    // Columns in ORDER BY are qualified as well.
    #[test]
    fn test_qualify_order_by() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT name FROM users ORDER BY email", &schema),
            "SELECT users.name FROM users ORDER BY users.email"
        );
    }

    // GROUP BY columns are qualified; the aggregate in HAVING has no column
    // reference and is unchanged.
    #[test]
    fn test_qualify_group_by_having() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
                &schema
            ),
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1"
        );
    }

    // `*` over a join lists the columns of each source in FROM/JOIN order.
    #[test]
    fn test_expand_star_join() {
        let schema = make_schema();
        let result = qualify(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }

    // The CTE body is qualified against `users`, and the outer query resolves
    // its columns against the CTE's output columns.
    #[test]
    fn test_cte_column_resolution() {
        let schema = make_schema();
        let result = qualify(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            &schema,
        );
        assert_eq!(
            result,
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active"
        );
    }

    // A derived table exposes the output columns of its subquery under its alias.
    #[test]
    fn test_derived_table_column_resolution() {
        let schema = make_schema();
        let result = qualify(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub"
        );
    }

    // Qualifying the expression keeps the select-item alias.
    #[test]
    fn test_preserve_expression_aliases() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT name AS user_name FROM users", &schema),
            "SELECT users.name AS user_name FROM users"
        );
    }

    // In the join condition `id` is ambiguous between the two tables and stays
    // bare, while `user_id` only exists in `orders` and gets qualified.
    #[test]
    fn test_qualify_join_on() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name FROM users JOIN orders ON id = user_id",
                &schema
            ),
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id"
        );
    }

    // `unknown_table` is not registered in the schema, so its columns cannot
    // be resolved and the query passes through unchanged.
    #[test]
    fn test_no_schema_columns_passthrough() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT x, y FROM unknown_table", &schema),
            "SELECT x, y FROM unknown_table"
        );
    }
}