1use std::collections::HashMap;
11
12use crate::ast::*;
13use crate::dialects::Dialect;
14use crate::schema::{Schema, normalize_identifier};
15
16pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
21 let dialect = schema.dialect();
22 match statement {
23 Statement::Select(sel) => {
24 let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
25 Statement::Select(qualified)
26 }
27 Statement::SetOperation(mut set_op) => {
28 set_op.left = Box::new(qualify_columns(*set_op.left, schema));
29 set_op.right = Box::new(qualify_columns(*set_op.right, schema));
30 Statement::SetOperation(set_op)
31 }
32 other => other,
33 }
34}
35
/// Column names exposed by one source (table, CTE, derived table, ...) of a `SELECT`.
#[derive(Clone, Debug)]
struct SourceColumns {
    /// Column names in the order the source provides them; may be empty.
    columns: Vec<String>,
}
42
43fn resolve_source_columns<S: Schema>(
45 sel: &SelectStatement,
46 schema: &S,
47 dialect: Dialect,
48 cte_columns: &HashMap<String, Vec<String>>,
49) -> HashMap<String, SourceColumns> {
50 let mut source_map: HashMap<String, SourceColumns> = HashMap::new();
51
52 if let Some(from) = &sel.from {
54 collect_source_columns(&from.source, schema, dialect, cte_columns, &mut source_map);
55 }
56
57 for join in &sel.joins {
59 collect_source_columns(&join.table, schema, dialect, cte_columns, &mut source_map);
60 }
61
62 source_map
63}
64
65fn collect_source_columns<S: Schema>(
67 source: &TableSource,
68 schema: &S,
69 dialect: Dialect,
70 cte_columns: &HashMap<String, Vec<String>>,
71 source_map: &mut HashMap<String, SourceColumns>,
72) {
73 match source {
74 TableSource::Table(table_ref) => {
75 let key = table_ref
76 .alias
77 .as_deref()
78 .unwrap_or(&table_ref.name)
79 .to_string();
80 let norm_key = normalize_identifier(&key, dialect);
81
82 let norm_name = normalize_identifier(&table_ref.name, dialect);
84 if let Some(cols) = cte_columns.get(&norm_name) {
85 source_map.insert(
86 norm_key,
87 SourceColumns {
88 columns: cols.clone(),
89 },
90 );
91 return;
92 }
93
94 let path = build_table_path(table_ref, dialect);
96 let path_refs: Vec<&str> = path.iter().map(|s| s.as_str()).collect();
97
98 if let Ok(cols) = schema.column_names(&path_refs) {
99 source_map.insert(norm_key, SourceColumns { columns: cols });
100 }
101 }
102 TableSource::Subquery { query, alias } => {
103 if let Some(alias) = alias {
104 let norm_alias = normalize_identifier(alias, dialect);
105 let cols = extract_output_columns(query, schema, dialect, cte_columns);
106 source_map.insert(norm_alias, SourceColumns { columns: cols });
107 }
108 }
109 TableSource::Lateral { source: inner } => {
110 collect_source_columns(inner, schema, dialect, cte_columns, source_map);
111 }
112 TableSource::Unnest { alias, .. } => {
113 if let Some(alias) = alias {
114 let norm_alias = normalize_identifier(alias, dialect);
115 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
117 }
118 }
119 TableSource::TableFunction { alias, .. } => {
120 if let Some(alias) = alias {
121 let norm_alias = normalize_identifier(alias, dialect);
122 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
123 }
124 }
125 }
126}
127
128fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
130 let mut path = Vec::new();
131 if let Some(cat) = &table_ref.catalog {
132 path.push(normalize_identifier(cat, dialect));
133 }
134 if let Some(sch) = &table_ref.schema {
135 path.push(normalize_identifier(sch, dialect));
136 }
137 path.push(normalize_identifier(&table_ref.name, dialect));
138 path
139}
140
141fn extract_output_columns<S: Schema>(
143 stmt: &Statement,
144 schema: &S,
145 dialect: Dialect,
146 cte_columns: &HashMap<String, Vec<String>>,
147) -> Vec<String> {
148 match stmt {
149 Statement::Select(sel) => {
150 let inner_sources = resolve_source_columns(sel, schema, dialect, cte_columns);
151 let mut cols = Vec::new();
152 for item in &sel.columns {
153 match item {
154 SelectItem::Wildcard => {
155 for_each_source_ordered(sel, dialect, &inner_sources, |sc| {
157 cols.extend(sc.columns.iter().cloned());
158 });
159 }
160 SelectItem::QualifiedWildcard { table } => {
161 let norm_table = normalize_identifier(table, dialect);
162 if let Some(sc) = inner_sources.get(&norm_table) {
163 cols.extend(sc.columns.iter().cloned());
164 }
165 }
166 SelectItem::Expr { alias, expr } => {
167 if let Some(alias) = alias {
168 cols.push(alias.clone());
169 } else {
170 cols.push(expr_output_name(expr));
171 }
172 }
173 }
174 }
175 cols
176 }
177 Statement::SetOperation(set_op) => {
178 extract_output_columns(&set_op.left, schema, dialect, cte_columns)
180 }
181 _ => vec![],
182 }
183}
184
185fn expr_output_name(expr: &Expr) -> String {
187 match expr {
188 Expr::Column { name, .. } => name.clone(),
189 Expr::Function { name, .. } => name.clone(),
190 Expr::TypedFunction { .. } => "_col".to_string(),
191 _ => "_col".to_string(),
192 }
193}
194
195fn for_each_source_ordered<F>(
197 sel: &SelectStatement,
198 dialect: Dialect,
199 source_map: &HashMap<String, SourceColumns>,
200 mut callback: F,
201) where
202 F: FnMut(&SourceColumns),
203{
204 if let Some(from) = &sel.from {
206 let key = source_key_for(&from.source, dialect);
207 if let Some(sc) = source_map.get(&key) {
208 callback(sc);
209 }
210 }
211 for join in &sel.joins {
213 let key = source_key_for(&join.table, dialect);
214 if let Some(sc) = source_map.get(&key) {
215 callback(sc);
216 }
217 }
218}
219
220fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
222 match source {
223 TableSource::Table(tr) => {
224 let name = tr.alias.as_deref().unwrap_or(&tr.name);
225 normalize_identifier(name, dialect)
226 }
227 TableSource::Subquery { alias, .. } => alias
228 .as_deref()
229 .map(|a| normalize_identifier(a, dialect))
230 .unwrap_or_default(),
231 TableSource::Lateral { source } => source_key_for(source, dialect),
232 TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => alias
233 .as_deref()
234 .map(|a| normalize_identifier(a, dialect))
235 .unwrap_or_default(),
236 }
237}
238
239fn qualify_select<S: Schema>(
241 mut sel: SelectStatement,
242 schema: &S,
243 dialect: Dialect,
244 outer_cte_columns: &HashMap<String, Vec<String>>,
245) -> SelectStatement {
246 let mut cte_columns = outer_cte_columns.clone();
248 for cte in &sel.ctes {
249 let cols = if !cte.columns.is_empty() {
250 cte.columns.clone()
252 } else {
253 extract_output_columns(&cte.query, schema, dialect, &cte_columns)
254 };
255 let norm_name = normalize_identifier(&cte.name, dialect);
256 cte_columns.insert(norm_name, cols);
257 }
258
259 sel.ctes = sel
261 .ctes
262 .into_iter()
263 .map(|mut cte| {
264 cte.query = Box::new(qualify_columns(*cte.query, schema));
265 cte
266 })
267 .collect();
268
269 if let Some(ref mut from) = sel.from {
271 qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
272 }
273 for join in &mut sel.joins {
274 qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
275 }
276
277 let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
279
280 let mut new_columns = Vec::new();
282 let old_columns = std::mem::take(&mut sel.columns);
283 for item in old_columns {
284 match item {
285 SelectItem::Wildcard => {
286 for_each_source_ordered(&sel, dialect, &source_map, |sc| {
288 for col_name in &sc.columns {
289 new_columns.push(SelectItem::Expr {
290 expr: Expr::Column {
291 table: None,
292 name: col_name.clone(),
293 quote_style: QuoteStyle::None,
294 table_quote_style: QuoteStyle::None,
295 },
296 alias: None,
297 });
298 }
299 });
300 }
301 SelectItem::QualifiedWildcard { table } => {
302 let norm_table = normalize_identifier(&table, dialect);
303 if let Some(sc) = source_map.get(&norm_table) {
304 for col_name in &sc.columns {
305 new_columns.push(SelectItem::Expr {
306 expr: Expr::Column {
307 table: Some(table.clone()),
308 name: col_name.clone(),
309 quote_style: QuoteStyle::None,
310 table_quote_style: QuoteStyle::None,
311 },
312 alias: None,
313 });
314 }
315 } else {
316 new_columns.push(SelectItem::QualifiedWildcard { table });
318 }
319 }
320 SelectItem::Expr { expr, alias } => {
321 let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
322 new_columns.push(SelectItem::Expr {
323 expr: qualified_expr,
324 alias,
325 });
326 }
327 }
328 }
329 sel.columns = new_columns;
330
331 if let Some(wh) = sel.where_clause {
333 sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
334 }
335 sel.group_by = sel
336 .group_by
337 .into_iter()
338 .map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
339 .collect();
340 if let Some(having) = sel.having {
341 sel.having = Some(qualify_expr(
342 having,
343 &source_map,
344 schema,
345 dialect,
346 &cte_columns,
347 ));
348 }
349 sel.order_by = sel
350 .order_by
351 .into_iter()
352 .map(|mut item| {
353 item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
354 item
355 })
356 .collect();
357 if let Some(qualify) = sel.qualify {
358 sel.qualify = Some(qualify_expr(
359 qualify,
360 &source_map,
361 schema,
362 dialect,
363 &cte_columns,
364 ));
365 }
366
367 for join in &mut sel.joins {
369 if let Some(on) = join.on.take() {
370 join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
371 }
372 }
373
374 sel
375}
376
377fn qualify_table_source<S: Schema>(
379 source: &mut TableSource,
380 schema: &S,
381 dialect: Dialect,
382 cte_columns: &HashMap<String, Vec<String>>,
383) {
384 match source {
385 TableSource::Subquery { query, .. } => {
386 *query = Box::new(qualify_columns_inner(
387 *query.clone(),
388 schema,
389 dialect,
390 cte_columns,
391 ));
392 }
393 TableSource::Lateral { source: inner } => {
394 qualify_table_source(inner, schema, dialect, cte_columns);
395 }
396 _ => {}
397 }
398}
399
400fn qualify_columns_inner<S: Schema>(
402 statement: Statement,
403 schema: &S,
404 dialect: Dialect,
405 cte_columns: &HashMap<String, Vec<String>>,
406) -> Statement {
407 match statement {
408 Statement::Select(sel) => {
409 Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
410 }
411 Statement::SetOperation(mut set_op) => {
412 set_op.left = Box::new(qualify_columns_inner(
413 *set_op.left,
414 schema,
415 dialect,
416 cte_columns,
417 ));
418 set_op.right = Box::new(qualify_columns_inner(
419 *set_op.right,
420 schema,
421 dialect,
422 cte_columns,
423 ));
424 Statement::SetOperation(set_op)
425 }
426 other => other,
427 }
428}
429
430fn qualify_expr<S: Schema>(
433 expr: Expr,
434 source_map: &HashMap<String, SourceColumns>,
435 schema: &S,
436 dialect: Dialect,
437 cte_columns: &HashMap<String, Vec<String>>,
438) -> Expr {
439 expr.transform(&|e| match e {
440 Expr::Column {
441 table: None,
442 name,
443 quote_style,
444 table_quote_style,
445 } => {
446 let norm_name = normalize_identifier(&name, dialect);
447 let resolved_source = resolve_column(&norm_name, source_map);
449 if let Some(source_name) = resolved_source {
450 Expr::Column {
451 table: Some(source_name),
452 name,
453 quote_style,
454 table_quote_style,
455 }
456 } else {
457 Expr::Column {
460 table: None,
461 name,
462 quote_style,
463 table_quote_style,
464 }
465 }
466 }
467 Expr::InSubquery {
469 expr,
470 subquery,
471 negated,
472 } => Expr::InSubquery {
473 expr,
474 subquery: Box::new(qualify_columns_inner(
475 *subquery,
476 schema,
477 dialect,
478 cte_columns,
479 )),
480 negated,
481 },
482 Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
483 *stmt,
484 schema,
485 dialect,
486 cte_columns,
487 ))),
488 Expr::Exists { subquery, negated } => Expr::Exists {
489 subquery: Box::new(qualify_columns_inner(
490 *subquery,
491 schema,
492 dialect,
493 cte_columns,
494 )),
495 negated,
496 },
497 other => other,
498 })
499}
500
501fn resolve_column(
505 norm_col_name: &str,
506 source_map: &HashMap<String, SourceColumns>,
507) -> Option<String> {
508 let mut matches: Vec<&str> = Vec::new();
509 for (source_name, sc) in source_map {
510 if sc
511 .columns
512 .iter()
513 .any(|c| c.eq_ignore_ascii_case(norm_col_name))
514 {
515 matches.push(source_name);
516 }
517 }
518 if matches.len() == 1 {
519 Some(matches[0].to_string())
520 } else {
521 None
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;

    /// Builds the ANSI schema shared by all tests: `users`, `orders` and `products`.
    fn make_schema() -> MappingSchema {
        let column = |name: &str, data_type: DataType| (name.to_string(), data_type);
        let money = || DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        };

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    column("id", DataType::Int),
                    column("name", DataType::Varchar(Some(255))),
                    column("email", DataType::Text),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    column("id", DataType::Int),
                    column("user_id", DataType::Int),
                    column("amount", money()),
                    column("status", DataType::Varchar(Some(50))),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["products"],
                vec![
                    column("id", DataType::Int),
                    column("name", DataType::Varchar(Some(255))),
                    column("price", money()),
                ],
            )
            .unwrap();
        schema
    }

    /// Parses `sql` as ANSI, qualifies it against `schema` and renders it back to SQL.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let parsed = parse(sql, Dialect::Ansi).unwrap();
        generate(&qualify_columns(parsed, schema), Dialect::Ansi)
    }

    /// Asserts that qualifying `sql` against the shared schema renders `expected`.
    #[track_caller]
    fn assert_qualified(sql: &str, expected: &str) {
        let schema = make_schema();
        assert_eq!(qualify(sql, &schema), expected);
    }

    #[test]
    fn test_expand_star() {
        assert_qualified(
            "SELECT * FROM users",
            "SELECT id, name, email FROM users",
        );
    }

    #[test]
    fn test_expand_qualified_wildcard() {
        assert_qualified(
            "SELECT users.* FROM users",
            "SELECT users.id, users.name, users.email FROM users",
        );
    }

    #[test]
    fn test_expand_star_with_alias() {
        assert_qualified(
            "SELECT * FROM users AS u",
            "SELECT id, name, email FROM users AS u",
        );
    }

    #[test]
    fn test_expand_qualified_wildcard_alias() {
        assert_qualified(
            "SELECT u.* FROM users AS u",
            "SELECT u.id, u.name, u.email FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table() {
        assert_qualified(
            "SELECT id, name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        assert_qualified(
            "SELECT id, name FROM users AS u",
            "SELECT u.id, u.name FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_already_qualified() {
        assert_qualified(
            "SELECT users.id, users.name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_unambiguous() {
        assert_qualified(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        // `id` exists in both tables, so it must stay unqualified.
        assert_qualified(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_where_clause() {
        assert_qualified(
            "SELECT name FROM users WHERE email = 'test@test.com'",
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'",
        );
    }

    #[test]
    fn test_qualify_order_by() {
        assert_qualified(
            "SELECT name FROM users ORDER BY email",
            "SELECT users.name FROM users ORDER BY users.email",
        );
    }

    #[test]
    fn test_qualify_group_by_having() {
        assert_qualified(
            "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1",
        );
    }

    #[test]
    fn test_expand_star_join() {
        assert_qualified(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_cte_column_resolution() {
        assert_qualified(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active",
        );
    }

    #[test]
    fn test_derived_table_column_resolution() {
        assert_qualified(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub",
        );
    }

    #[test]
    fn test_preserve_expression_aliases() {
        assert_qualified(
            "SELECT name AS user_name FROM users",
            "SELECT users.name AS user_name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_on() {
        // In the ON clause `id` is ambiguous and stays bare; `user_id` is unique.
        assert_qualified(
            "SELECT name FROM users JOIN orders ON id = user_id",
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id",
        );
    }

    #[test]
    fn test_no_schema_columns_passthrough() {
        // The table is not in the schema, so nothing can be qualified.
        assert_qualified(
            "SELECT x, y FROM unknown_table",
            "SELECT x, y FROM unknown_table",
        );
    }
}