1use std::collections::HashMap;
11
12use crate::ast::*;
13use crate::dialects::Dialect;
14use crate::schema::{Schema, normalize_identifier};
15
16pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
21 let dialect = schema.dialect();
22 match statement {
23 Statement::Select(sel) => {
24 let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
25 Statement::Select(qualified)
26 }
27 Statement::SetOperation(mut set_op) => {
28 set_op.left = Box::new(qualify_columns(*set_op.left, schema));
29 set_op.right = Box::new(qualify_columns(*set_op.right, schema));
30 Statement::SetOperation(set_op)
31 }
32 other => other,
33 }
34}
35
/// Column names exposed by one source (table, CTE, derived table, ...) of a
/// `FROM`/`JOIN` clause.
#[derive(Clone, Debug)]
struct SourceColumns {
    /// Column names in the order they were obtained; empty when the source's
    /// columns could not be determined.
    columns: Vec<String>,
}
42
43fn resolve_source_columns<S: Schema>(
45 sel: &SelectStatement,
46 schema: &S,
47 dialect: Dialect,
48 cte_columns: &HashMap<String, Vec<String>>,
49) -> HashMap<String, SourceColumns> {
50 let mut source_map: HashMap<String, SourceColumns> = HashMap::new();
51
52 if let Some(from) = &sel.from {
54 collect_source_columns(&from.source, schema, dialect, cte_columns, &mut source_map);
55 }
56
57 for join in &sel.joins {
59 collect_source_columns(&join.table, schema, dialect, cte_columns, &mut source_map);
60 }
61
62 source_map
63}
64
65fn collect_source_columns<S: Schema>(
67 source: &TableSource,
68 schema: &S,
69 dialect: Dialect,
70 cte_columns: &HashMap<String, Vec<String>>,
71 source_map: &mut HashMap<String, SourceColumns>,
72) {
73 match source {
74 TableSource::Table(table_ref) => {
75 let key = table_ref
76 .alias
77 .as_deref()
78 .unwrap_or(&table_ref.name)
79 .to_string();
80 let norm_key = normalize_identifier(&key, dialect);
81
82 let norm_name = normalize_identifier(&table_ref.name, dialect);
84 if let Some(cols) = cte_columns.get(&norm_name) {
85 source_map.insert(
86 norm_key,
87 SourceColumns {
88 columns: cols.clone(),
89 },
90 );
91 return;
92 }
93
94 let path = build_table_path(table_ref, dialect);
96 let path_refs: Vec<&str> = path.iter().map(|s| s.as_str()).collect();
97
98 if let Ok(cols) = schema.column_names(&path_refs) {
99 source_map.insert(norm_key, SourceColumns { columns: cols });
100 }
101 }
102 TableSource::Subquery { query, alias, .. } => {
103 if let Some(alias) = alias {
104 let norm_alias = normalize_identifier(alias, dialect);
105 let cols = extract_output_columns(query, schema, dialect, cte_columns);
106 source_map.insert(norm_alias, SourceColumns { columns: cols });
107 }
108 }
109 TableSource::Lateral { source: inner } => {
110 collect_source_columns(inner, schema, dialect, cte_columns, source_map);
111 }
112 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
113 collect_source_columns(source, schema, dialect, cte_columns, source_map);
114 if let Some(alias) = alias {
115 let norm_alias = normalize_identifier(alias, dialect);
116 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
117 }
118 }
119 TableSource::Unnest { alias, .. } => {
120 if let Some(alias) = alias {
121 let norm_alias = normalize_identifier(alias, dialect);
122 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
124 }
125 }
126 TableSource::TableFunction { alias, .. } => {
127 if let Some(alias) = alias {
128 let norm_alias = normalize_identifier(alias, dialect);
129 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
130 }
131 }
132 }
133}
134
135fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
137 let mut path = Vec::new();
138 if let Some(cat) = &table_ref.catalog {
139 path.push(normalize_identifier(cat, dialect));
140 }
141 if let Some(sch) = &table_ref.schema {
142 path.push(normalize_identifier(sch, dialect));
143 }
144 path.push(normalize_identifier(&table_ref.name, dialect));
145 path
146}
147
148fn extract_output_columns<S: Schema>(
150 stmt: &Statement,
151 schema: &S,
152 dialect: Dialect,
153 cte_columns: &HashMap<String, Vec<String>>,
154) -> Vec<String> {
155 match stmt {
156 Statement::Select(sel) => {
157 let inner_sources = resolve_source_columns(sel, schema, dialect, cte_columns);
158 let mut cols = Vec::new();
159 for item in &sel.columns {
160 match item {
161 SelectItem::Wildcard => {
162 for_each_source_ordered(sel, dialect, &inner_sources, |sc| {
164 cols.extend(sc.columns.iter().cloned());
165 });
166 }
167 SelectItem::QualifiedWildcard { table } => {
168 let norm_table = normalize_identifier(table, dialect);
169 if let Some(sc) = inner_sources.get(&norm_table) {
170 cols.extend(sc.columns.iter().cloned());
171 }
172 }
173 SelectItem::Expr { alias, expr, .. } => {
174 if let Some(alias) = alias {
175 cols.push(alias.clone());
176 } else {
177 cols.push(expr_output_name(expr));
178 }
179 }
180 }
181 }
182 cols
183 }
184 Statement::SetOperation(set_op) => {
185 extract_output_columns(&set_op.left, schema, dialect, cte_columns)
187 }
188 _ => vec![],
189 }
190}
191
192fn expr_output_name(expr: &Expr) -> String {
194 match expr {
195 Expr::Column { name, .. } => name.clone(),
196 Expr::Function { name, .. } => name.clone(),
197 Expr::TypedFunction { .. } => "_col".to_string(),
198 _ => "_col".to_string(),
199 }
200}
201
202fn for_each_source_ordered<F>(
204 sel: &SelectStatement,
205 dialect: Dialect,
206 source_map: &HashMap<String, SourceColumns>,
207 mut callback: F,
208) where
209 F: FnMut(&SourceColumns),
210{
211 if let Some(from) = &sel.from {
213 let key = source_key_for(&from.source, dialect);
214 if let Some(sc) = source_map.get(&key) {
215 callback(sc);
216 }
217 }
218 for join in &sel.joins {
220 let key = source_key_for(&join.table, dialect);
221 if let Some(sc) = source_map.get(&key) {
222 callback(sc);
223 }
224 }
225}
226
227fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
229 match source {
230 TableSource::Table(tr) => {
231 let name = tr.alias.as_deref().unwrap_or(&tr.name);
232 normalize_identifier(name, dialect)
233 }
234 TableSource::Subquery { alias, .. } => alias
235 .as_deref()
236 .map(|a| normalize_identifier(a, dialect))
237 .unwrap_or_default(),
238 TableSource::Lateral { source } => source_key_for(source, dialect),
239 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
240 if let Some(a) = alias {
241 normalize_identifier(a, dialect)
242 } else {
243 source_key_for(source, dialect)
244 }
245 }
246 TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => alias
247 .as_deref()
248 .map(|a| normalize_identifier(a, dialect))
249 .unwrap_or_default(),
250 }
251}
252
253fn qualify_select<S: Schema>(
255 mut sel: SelectStatement,
256 schema: &S,
257 dialect: Dialect,
258 outer_cte_columns: &HashMap<String, Vec<String>>,
259) -> SelectStatement {
260 let mut cte_columns = outer_cte_columns.clone();
262 for cte in &sel.ctes {
263 let cols = if !cte.columns.is_empty() {
264 cte.columns.clone()
266 } else {
267 extract_output_columns(&cte.query, schema, dialect, &cte_columns)
268 };
269 let norm_name = normalize_identifier(&cte.name, dialect);
270 cte_columns.insert(norm_name, cols);
271 }
272
273 sel.ctes = sel
275 .ctes
276 .into_iter()
277 .map(|mut cte| {
278 cte.query = Box::new(qualify_columns(*cte.query, schema));
279 cte
280 })
281 .collect();
282
283 if let Some(ref mut from) = sel.from {
285 qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
286 }
287 for join in &mut sel.joins {
288 qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
289 }
290
291 let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
293
294 let mut new_columns = Vec::new();
296 let old_columns = std::mem::take(&mut sel.columns);
297 for item in old_columns {
298 match item {
299 SelectItem::Wildcard => {
300 for_each_source_ordered(&sel, dialect, &source_map, |sc| {
302 for col_name in &sc.columns {
303 new_columns.push(SelectItem::Expr {
304 expr: Expr::Column {
305 table: None,
306 name: col_name.clone(),
307 quote_style: QuoteStyle::None,
308 table_quote_style: QuoteStyle::None,
309 },
310 alias: None,
311 alias_quote_style: QuoteStyle::None,
312 });
313 }
314 });
315 }
316 SelectItem::QualifiedWildcard { table } => {
317 let norm_table = normalize_identifier(&table, dialect);
318 if let Some(sc) = source_map.get(&norm_table) {
319 for col_name in &sc.columns {
320 new_columns.push(SelectItem::Expr {
321 expr: Expr::Column {
322 table: Some(table.clone()),
323 name: col_name.clone(),
324 quote_style: QuoteStyle::None,
325 table_quote_style: QuoteStyle::None,
326 },
327 alias: None,
328 alias_quote_style: QuoteStyle::None,
329 });
330 }
331 } else {
332 new_columns.push(SelectItem::QualifiedWildcard { table });
334 }
335 }
336 SelectItem::Expr { expr, alias, alias_quote_style, .. } => {
337 let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
338 new_columns.push(SelectItem::Expr {
339 expr: qualified_expr,
340 alias,
341 alias_quote_style,
342 });
343 }
344 }
345 }
346 sel.columns = new_columns;
347
348 if let Some(wh) = sel.where_clause {
350 sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
351 }
352 sel.group_by = sel
353 .group_by
354 .into_iter()
355 .map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
356 .collect();
357 if let Some(having) = sel.having {
358 sel.having = Some(qualify_expr(
359 having,
360 &source_map,
361 schema,
362 dialect,
363 &cte_columns,
364 ));
365 }
366 sel.order_by = sel
367 .order_by
368 .into_iter()
369 .map(|mut item| {
370 item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
371 item
372 })
373 .collect();
374 if let Some(qualify) = sel.qualify {
375 sel.qualify = Some(qualify_expr(
376 qualify,
377 &source_map,
378 schema,
379 dialect,
380 &cte_columns,
381 ));
382 }
383
384 for join in &mut sel.joins {
386 if let Some(on) = join.on.take() {
387 join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
388 }
389 }
390
391 sel
392}
393
394fn qualify_table_source<S: Schema>(
396 source: &mut TableSource,
397 schema: &S,
398 dialect: Dialect,
399 cte_columns: &HashMap<String, Vec<String>>,
400) {
401 match source {
402 TableSource::Subquery { query, .. } => {
403 *query = Box::new(qualify_columns_inner(
404 *query.clone(),
405 schema,
406 dialect,
407 cte_columns,
408 ));
409 }
410 TableSource::Lateral { source: inner } => {
411 qualify_table_source(inner, schema, dialect, cte_columns);
412 }
413 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
414 qualify_table_source(source, schema, dialect, cte_columns);
415 }
416 _ => {}
417 }
418}
419
420fn qualify_columns_inner<S: Schema>(
422 statement: Statement,
423 schema: &S,
424 dialect: Dialect,
425 cte_columns: &HashMap<String, Vec<String>>,
426) -> Statement {
427 match statement {
428 Statement::Select(sel) => {
429 Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
430 }
431 Statement::SetOperation(mut set_op) => {
432 set_op.left = Box::new(qualify_columns_inner(
433 *set_op.left,
434 schema,
435 dialect,
436 cte_columns,
437 ));
438 set_op.right = Box::new(qualify_columns_inner(
439 *set_op.right,
440 schema,
441 dialect,
442 cte_columns,
443 ));
444 Statement::SetOperation(set_op)
445 }
446 other => other,
447 }
448}
449
450fn qualify_expr<S: Schema>(
453 expr: Expr,
454 source_map: &HashMap<String, SourceColumns>,
455 schema: &S,
456 dialect: Dialect,
457 cte_columns: &HashMap<String, Vec<String>>,
458) -> Expr {
459 expr.transform(&|e| match e {
460 Expr::Column {
461 table: None,
462 name,
463 quote_style,
464 table_quote_style,
465 } => {
466 let norm_name = normalize_identifier(&name, dialect);
467 let resolved_source = resolve_column(&norm_name, source_map);
469 if let Some(source_name) = resolved_source {
470 Expr::Column {
471 table: Some(source_name),
472 name,
473 quote_style,
474 table_quote_style,
475 }
476 } else {
477 Expr::Column {
480 table: None,
481 name,
482 quote_style,
483 table_quote_style,
484 }
485 }
486 }
487 Expr::InSubquery {
489 expr,
490 subquery,
491 negated,
492 } => Expr::InSubquery {
493 expr,
494 subquery: Box::new(qualify_columns_inner(
495 *subquery,
496 schema,
497 dialect,
498 cte_columns,
499 )),
500 negated,
501 },
502 Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
503 *stmt,
504 schema,
505 dialect,
506 cte_columns,
507 ))),
508 Expr::Exists { subquery, negated } => Expr::Exists {
509 subquery: Box::new(qualify_columns_inner(
510 *subquery,
511 schema,
512 dialect,
513 cte_columns,
514 )),
515 negated,
516 },
517 other => other,
518 })
519}
520
521fn resolve_column(
525 norm_col_name: &str,
526 source_map: &HashMap<String, SourceColumns>,
527) -> Option<String> {
528 let mut matches: Vec<&str> = Vec::new();
529 for (source_name, sc) in source_map {
530 if sc
531 .columns
532 .iter()
533 .any(|c| c.eq_ignore_ascii_case(norm_col_name))
534 {
535 matches.push(source_name);
536 }
537 }
538 if matches.len() == 1 {
539 Some(matches[0].to_string())
540 } else {
541 None
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;

    /// Builds the ANSI fixture schema with the `users`, `orders` and
    /// `products` tables used by every test.
    fn make_schema() -> MappingSchema {
        let column = |name: &str, data_type: DataType| (name.to_string(), data_type);
        let money = || DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        };

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    column("id", DataType::Int),
                    column("name", DataType::Varchar(Some(255))),
                    column("email", DataType::Text),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    column("id", DataType::Int),
                    column("user_id", DataType::Int),
                    column("amount", money()),
                    column("status", DataType::Varchar(Some(50))),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["products"],
                vec![
                    column("id", DataType::Int),
                    column("name", DataType::Varchar(Some(255))),
                    column("price", money()),
                ],
            )
            .unwrap();
        schema
    }

    /// Parses `sql` as ANSI, qualifies it against `schema` and renders it back.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let parsed = parse(sql, Dialect::Ansi).unwrap();
        generate(&qualify_columns(parsed, schema), Dialect::Ansi)
    }

    /// Asserts that qualifying `sql` against the fixture schema renders
    /// exactly `expected`.
    #[track_caller]
    fn assert_qualified(sql: &str, expected: &str) {
        let schema = make_schema();
        assert_eq!(qualify(sql, &schema), expected);
    }

    #[test]
    fn test_expand_star() {
        assert_qualified("SELECT * FROM users", "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_expand_qualified_wildcard() {
        assert_qualified(
            "SELECT users.* FROM users",
            "SELECT users.id, users.name, users.email FROM users",
        );
    }

    #[test]
    fn test_expand_star_with_alias() {
        assert_qualified(
            "SELECT * FROM users AS u",
            "SELECT id, name, email FROM users AS u",
        );
    }

    #[test]
    fn test_expand_qualified_wildcard_alias() {
        assert_qualified(
            "SELECT u.* FROM users AS u",
            "SELECT u.id, u.name, u.email FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table() {
        assert_qualified(
            "SELECT id, name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        assert_qualified(
            "SELECT id, name FROM users AS u",
            "SELECT u.id, u.name FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_already_qualified() {
        assert_qualified(
            "SELECT users.id, users.name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_unambiguous() {
        assert_qualified(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    // `id` exists in both tables, so it must stay unqualified.
    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        assert_qualified(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_where_clause() {
        assert_qualified(
            "SELECT name FROM users WHERE email = 'test@test.com'",
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'",
        );
    }

    #[test]
    fn test_qualify_order_by() {
        assert_qualified(
            "SELECT name FROM users ORDER BY email",
            "SELECT users.name FROM users ORDER BY users.email",
        );
    }

    #[test]
    fn test_qualify_group_by_having() {
        assert_qualified(
            "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1",
        );
    }

    #[test]
    fn test_expand_star_join() {
        assert_qualified(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_cte_column_resolution() {
        assert_qualified(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active",
        );
    }

    #[test]
    fn test_derived_table_column_resolution() {
        assert_qualified(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub",
        );
    }

    #[test]
    fn test_preserve_expression_aliases() {
        assert_qualified(
            "SELECT name AS user_name FROM users",
            "SELECT users.name AS user_name FROM users",
        );
    }

    // In the ON clause `id` is ambiguous and stays bare, while `user_id`
    // resolves to `orders`.
    #[test]
    fn test_qualify_join_on() {
        assert_qualified(
            "SELECT name FROM users JOIN orders ON id = user_id",
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id",
        );
    }

    // Columns of a table the schema does not know are passed through.
    #[test]
    fn test_no_schema_columns_passthrough() {
        assert_qualified(
            "SELECT x, y FROM unknown_table",
            "SELECT x, y FROM unknown_table",
        );
    }
}