1use std::collections::HashMap;
11
12use crate::ast::*;
13use crate::dialects::Dialect;
14use crate::schema::{Schema, normalize_identifier};
15
16pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
21 let dialect = schema.dialect();
22 match statement {
23 Statement::Select(sel) => {
24 let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
25 Statement::Select(qualified)
26 }
27 Statement::SetOperation(mut set_op) => {
28 set_op.left = Box::new(qualify_columns(*set_op.left, schema));
29 set_op.right = Box::new(qualify_columns(*set_op.right, schema));
30 Statement::SetOperation(set_op)
31 }
32 other => other,
33 }
34}
35
/// The column names exposed by a single `FROM`/`JOIN` source.
#[derive(Clone, Debug)]
struct SourceColumns {
    /// Column names in the order they were reported for the source. May be
    /// empty when the source's columns could not be determined.
    columns: Vec<String>,
}
42
43fn resolve_source_columns<S: Schema>(
45 sel: &SelectStatement,
46 schema: &S,
47 dialect: Dialect,
48 cte_columns: &HashMap<String, Vec<String>>,
49) -> HashMap<String, SourceColumns> {
50 let mut source_map: HashMap<String, SourceColumns> = HashMap::new();
51
52 if let Some(from) = &sel.from {
54 collect_source_columns(&from.source, schema, dialect, cte_columns, &mut source_map);
55 }
56
57 for join in &sel.joins {
59 collect_source_columns(&join.table, schema, dialect, cte_columns, &mut source_map);
60 }
61
62 source_map
63}
64
65fn collect_source_columns<S: Schema>(
67 source: &TableSource,
68 schema: &S,
69 dialect: Dialect,
70 cte_columns: &HashMap<String, Vec<String>>,
71 source_map: &mut HashMap<String, SourceColumns>,
72) {
73 match source {
74 TableSource::Table(table_ref) => {
75 let key = table_ref
76 .alias
77 .as_deref()
78 .unwrap_or(&table_ref.name)
79 .to_string();
80 let norm_key = normalize_identifier(&key, dialect);
81
82 let norm_name = normalize_identifier(&table_ref.name, dialect);
84 if let Some(cols) = cte_columns.get(&norm_name) {
85 source_map.insert(
86 norm_key,
87 SourceColumns {
88 columns: cols.clone(),
89 },
90 );
91 return;
92 }
93
94 let path = build_table_path(table_ref, dialect);
96 let path_refs: Vec<&str> = path.iter().map(|s| s.as_str()).collect();
97
98 if let Ok(cols) = schema.column_names(&path_refs) {
99 source_map.insert(norm_key, SourceColumns { columns: cols });
100 }
101 }
102 TableSource::Subquery { query, alias } => {
103 if let Some(alias) = alias {
104 let norm_alias = normalize_identifier(alias, dialect);
105 let cols = extract_output_columns(query, schema, dialect, cte_columns);
106 source_map.insert(norm_alias, SourceColumns { columns: cols });
107 }
108 }
109 TableSource::Lateral { source: inner } => {
110 collect_source_columns(inner, schema, dialect, cte_columns, source_map);
111 }
112 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
113 collect_source_columns(source, schema, dialect, cte_columns, source_map);
114 if let Some(alias) = alias {
115 let norm_alias = normalize_identifier(alias, dialect);
116 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
117 }
118 }
119 TableSource::Unnest { alias, .. } => {
120 if let Some(alias) = alias {
121 let norm_alias = normalize_identifier(alias, dialect);
122 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
124 }
125 }
126 TableSource::TableFunction { alias, .. } => {
127 if let Some(alias) = alias {
128 let norm_alias = normalize_identifier(alias, dialect);
129 source_map.insert(norm_alias, SourceColumns { columns: vec![] });
130 }
131 }
132 }
133}
134
135fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
137 let mut path = Vec::new();
138 if let Some(cat) = &table_ref.catalog {
139 path.push(normalize_identifier(cat, dialect));
140 }
141 if let Some(sch) = &table_ref.schema {
142 path.push(normalize_identifier(sch, dialect));
143 }
144 path.push(normalize_identifier(&table_ref.name, dialect));
145 path
146}
147
148fn extract_output_columns<S: Schema>(
150 stmt: &Statement,
151 schema: &S,
152 dialect: Dialect,
153 cte_columns: &HashMap<String, Vec<String>>,
154) -> Vec<String> {
155 match stmt {
156 Statement::Select(sel) => {
157 let inner_sources = resolve_source_columns(sel, schema, dialect, cte_columns);
158 let mut cols = Vec::new();
159 for item in &sel.columns {
160 match item {
161 SelectItem::Wildcard => {
162 for_each_source_ordered(sel, dialect, &inner_sources, |sc| {
164 cols.extend(sc.columns.iter().cloned());
165 });
166 }
167 SelectItem::QualifiedWildcard { table } => {
168 let norm_table = normalize_identifier(table, dialect);
169 if let Some(sc) = inner_sources.get(&norm_table) {
170 cols.extend(sc.columns.iter().cloned());
171 }
172 }
173 SelectItem::Expr { alias, expr } => {
174 if let Some(alias) = alias {
175 cols.push(alias.clone());
176 } else {
177 cols.push(expr_output_name(expr));
178 }
179 }
180 }
181 }
182 cols
183 }
184 Statement::SetOperation(set_op) => {
185 extract_output_columns(&set_op.left, schema, dialect, cte_columns)
187 }
188 _ => vec![],
189 }
190}
191
192fn expr_output_name(expr: &Expr) -> String {
194 match expr {
195 Expr::Column { name, .. } => name.clone(),
196 Expr::Function { name, .. } => name.clone(),
197 Expr::TypedFunction { .. } => "_col".to_string(),
198 _ => "_col".to_string(),
199 }
200}
201
202fn for_each_source_ordered<F>(
204 sel: &SelectStatement,
205 dialect: Dialect,
206 source_map: &HashMap<String, SourceColumns>,
207 mut callback: F,
208) where
209 F: FnMut(&SourceColumns),
210{
211 if let Some(from) = &sel.from {
213 let key = source_key_for(&from.source, dialect);
214 if let Some(sc) = source_map.get(&key) {
215 callback(sc);
216 }
217 }
218 for join in &sel.joins {
220 let key = source_key_for(&join.table, dialect);
221 if let Some(sc) = source_map.get(&key) {
222 callback(sc);
223 }
224 }
225}
226
227fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
229 match source {
230 TableSource::Table(tr) => {
231 let name = tr.alias.as_deref().unwrap_or(&tr.name);
232 normalize_identifier(name, dialect)
233 }
234 TableSource::Subquery { alias, .. } => alias
235 .as_deref()
236 .map(|a| normalize_identifier(a, dialect))
237 .unwrap_or_default(),
238 TableSource::Lateral { source } => source_key_for(source, dialect),
239 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
240 if let Some(a) = alias {
241 normalize_identifier(a, dialect)
242 } else {
243 source_key_for(source, dialect)
244 }
245 }
246 TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => alias
247 .as_deref()
248 .map(|a| normalize_identifier(a, dialect))
249 .unwrap_or_default(),
250 }
251}
252
253fn qualify_select<S: Schema>(
255 mut sel: SelectStatement,
256 schema: &S,
257 dialect: Dialect,
258 outer_cte_columns: &HashMap<String, Vec<String>>,
259) -> SelectStatement {
260 let mut cte_columns = outer_cte_columns.clone();
262 for cte in &sel.ctes {
263 let cols = if !cte.columns.is_empty() {
264 cte.columns.clone()
266 } else {
267 extract_output_columns(&cte.query, schema, dialect, &cte_columns)
268 };
269 let norm_name = normalize_identifier(&cte.name, dialect);
270 cte_columns.insert(norm_name, cols);
271 }
272
273 sel.ctes = sel
275 .ctes
276 .into_iter()
277 .map(|mut cte| {
278 cte.query = Box::new(qualify_columns(*cte.query, schema));
279 cte
280 })
281 .collect();
282
283 if let Some(ref mut from) = sel.from {
285 qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
286 }
287 for join in &mut sel.joins {
288 qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
289 }
290
291 let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
293
294 let mut new_columns = Vec::new();
296 let old_columns = std::mem::take(&mut sel.columns);
297 for item in old_columns {
298 match item {
299 SelectItem::Wildcard => {
300 for_each_source_ordered(&sel, dialect, &source_map, |sc| {
302 for col_name in &sc.columns {
303 new_columns.push(SelectItem::Expr {
304 expr: Expr::Column {
305 table: None,
306 name: col_name.clone(),
307 quote_style: QuoteStyle::None,
308 table_quote_style: QuoteStyle::None,
309 },
310 alias: None,
311 });
312 }
313 });
314 }
315 SelectItem::QualifiedWildcard { table } => {
316 let norm_table = normalize_identifier(&table, dialect);
317 if let Some(sc) = source_map.get(&norm_table) {
318 for col_name in &sc.columns {
319 new_columns.push(SelectItem::Expr {
320 expr: Expr::Column {
321 table: Some(table.clone()),
322 name: col_name.clone(),
323 quote_style: QuoteStyle::None,
324 table_quote_style: QuoteStyle::None,
325 },
326 alias: None,
327 });
328 }
329 } else {
330 new_columns.push(SelectItem::QualifiedWildcard { table });
332 }
333 }
334 SelectItem::Expr { expr, alias } => {
335 let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
336 new_columns.push(SelectItem::Expr {
337 expr: qualified_expr,
338 alias,
339 });
340 }
341 }
342 }
343 sel.columns = new_columns;
344
345 if let Some(wh) = sel.where_clause {
347 sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
348 }
349 sel.group_by = sel
350 .group_by
351 .into_iter()
352 .map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
353 .collect();
354 if let Some(having) = sel.having {
355 sel.having = Some(qualify_expr(
356 having,
357 &source_map,
358 schema,
359 dialect,
360 &cte_columns,
361 ));
362 }
363 sel.order_by = sel
364 .order_by
365 .into_iter()
366 .map(|mut item| {
367 item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
368 item
369 })
370 .collect();
371 if let Some(qualify) = sel.qualify {
372 sel.qualify = Some(qualify_expr(
373 qualify,
374 &source_map,
375 schema,
376 dialect,
377 &cte_columns,
378 ));
379 }
380
381 for join in &mut sel.joins {
383 if let Some(on) = join.on.take() {
384 join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
385 }
386 }
387
388 sel
389}
390
391fn qualify_table_source<S: Schema>(
393 source: &mut TableSource,
394 schema: &S,
395 dialect: Dialect,
396 cte_columns: &HashMap<String, Vec<String>>,
397) {
398 match source {
399 TableSource::Subquery { query, .. } => {
400 *query = Box::new(qualify_columns_inner(
401 *query.clone(),
402 schema,
403 dialect,
404 cte_columns,
405 ));
406 }
407 TableSource::Lateral { source: inner } => {
408 qualify_table_source(inner, schema, dialect, cte_columns);
409 }
410 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
411 qualify_table_source(source, schema, dialect, cte_columns);
412 }
413 _ => {}
414 }
415}
416
417fn qualify_columns_inner<S: Schema>(
419 statement: Statement,
420 schema: &S,
421 dialect: Dialect,
422 cte_columns: &HashMap<String, Vec<String>>,
423) -> Statement {
424 match statement {
425 Statement::Select(sel) => {
426 Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
427 }
428 Statement::SetOperation(mut set_op) => {
429 set_op.left = Box::new(qualify_columns_inner(
430 *set_op.left,
431 schema,
432 dialect,
433 cte_columns,
434 ));
435 set_op.right = Box::new(qualify_columns_inner(
436 *set_op.right,
437 schema,
438 dialect,
439 cte_columns,
440 ));
441 Statement::SetOperation(set_op)
442 }
443 other => other,
444 }
445}
446
447fn qualify_expr<S: Schema>(
450 expr: Expr,
451 source_map: &HashMap<String, SourceColumns>,
452 schema: &S,
453 dialect: Dialect,
454 cte_columns: &HashMap<String, Vec<String>>,
455) -> Expr {
456 expr.transform(&|e| match e {
457 Expr::Column {
458 table: None,
459 name,
460 quote_style,
461 table_quote_style,
462 } => {
463 let norm_name = normalize_identifier(&name, dialect);
464 let resolved_source = resolve_column(&norm_name, source_map);
466 if let Some(source_name) = resolved_source {
467 Expr::Column {
468 table: Some(source_name),
469 name,
470 quote_style,
471 table_quote_style,
472 }
473 } else {
474 Expr::Column {
477 table: None,
478 name,
479 quote_style,
480 table_quote_style,
481 }
482 }
483 }
484 Expr::InSubquery {
486 expr,
487 subquery,
488 negated,
489 } => Expr::InSubquery {
490 expr,
491 subquery: Box::new(qualify_columns_inner(
492 *subquery,
493 schema,
494 dialect,
495 cte_columns,
496 )),
497 negated,
498 },
499 Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
500 *stmt,
501 schema,
502 dialect,
503 cte_columns,
504 ))),
505 Expr::Exists { subquery, negated } => Expr::Exists {
506 subquery: Box::new(qualify_columns_inner(
507 *subquery,
508 schema,
509 dialect,
510 cte_columns,
511 )),
512 negated,
513 },
514 other => other,
515 })
516}
517
518fn resolve_column(
522 norm_col_name: &str,
523 source_map: &HashMap<String, SourceColumns>,
524) -> Option<String> {
525 let mut matches: Vec<&str> = Vec::new();
526 for (source_name, sc) in source_map {
527 if sc
528 .columns
529 .iter()
530 .any(|c| c.eq_ignore_ascii_case(norm_col_name))
531 {
532 matches.push(source_name);
533 }
534 }
535 if matches.len() == 1 {
536 Some(matches[0].to_string())
537 } else {
538 None
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;

    /// The `DECIMAL(10, 2)` type shared by the monetary fixture columns.
    fn money() -> DataType {
        DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        }
    }

    /// Builds the ANSI fixture schema with the `users`, `orders` and
    /// `products` tables.
    fn make_schema() -> MappingSchema {
        let tables = vec![
            (
                "users",
                vec![
                    ("id", DataType::Int),
                    ("name", DataType::Varchar(Some(255))),
                    ("email", DataType::Text),
                ],
            ),
            (
                "orders",
                vec![
                    ("id", DataType::Int),
                    ("user_id", DataType::Int),
                    ("amount", money()),
                    ("status", DataType::Varchar(Some(50))),
                ],
            ),
            (
                "products",
                vec![
                    ("id", DataType::Int),
                    ("name", DataType::Varchar(Some(255))),
                    ("price", money()),
                ],
            ),
        ];

        let mut schema = MappingSchema::new(Dialect::Ansi);
        for (table, columns) in tables {
            let columns = columns
                .into_iter()
                .map(|(name, data_type)| (name.to_string(), data_type))
                .collect::<Vec<(String, DataType)>>();
            schema.add_table(&[table], columns).unwrap();
        }
        schema
    }

    /// Parses `sql` as ANSI, qualifies it against `schema` and renders it back.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let parsed = parse(sql, Dialect::Ansi).unwrap();
        generate(&qualify_columns(parsed, schema), Dialect::Ansi)
    }

    /// Asserts that qualifying `sql` against the fixture schema yields `expected`.
    #[track_caller]
    fn assert_qualified(sql: &str, expected: &str) {
        let schema = make_schema();
        assert_eq!(qualify(sql, &schema), expected);
    }

    #[test]
    fn test_expand_star() {
        assert_qualified("SELECT * FROM users", "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_expand_qualified_wildcard() {
        assert_qualified(
            "SELECT users.* FROM users",
            "SELECT users.id, users.name, users.email FROM users",
        );
    }

    #[test]
    fn test_expand_star_with_alias() {
        assert_qualified(
            "SELECT * FROM users AS u",
            "SELECT id, name, email FROM users AS u",
        );
    }

    #[test]
    fn test_expand_qualified_wildcard_alias() {
        assert_qualified(
            "SELECT u.* FROM users AS u",
            "SELECT u.id, u.name, u.email FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table() {
        assert_qualified(
            "SELECT id, name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        assert_qualified(
            "SELECT id, name FROM users AS u",
            "SELECT u.id, u.name FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_already_qualified() {
        assert_qualified(
            "SELECT users.id, users.name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_unambiguous() {
        assert_qualified(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        // `id` exists in both tables, so it must stay unqualified.
        assert_qualified(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_where_clause() {
        assert_qualified(
            "SELECT name FROM users WHERE email = 'test@test.com'",
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'",
        );
    }

    #[test]
    fn test_qualify_order_by() {
        assert_qualified(
            "SELECT name FROM users ORDER BY email",
            "SELECT users.name FROM users ORDER BY users.email",
        );
    }

    #[test]
    fn test_qualify_group_by_having() {
        assert_qualified(
            "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1",
        );
    }

    #[test]
    fn test_expand_star_join() {
        assert_qualified(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_cte_column_resolution() {
        assert_qualified(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active",
        );
    }

    #[test]
    fn test_derived_table_column_resolution() {
        assert_qualified(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub",
        );
    }

    #[test]
    fn test_preserve_expression_aliases() {
        assert_qualified(
            "SELECT name AS user_name FROM users",
            "SELECT users.name AS user_name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_on() {
        // `id` is ambiguous between the two tables; `user_id` is not.
        assert_qualified(
            "SELECT name FROM users JOIN orders ON id = user_id",
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id",
        );
    }

    #[test]
    fn test_no_schema_columns_passthrough() {
        // The table is absent from the schema, so nothing can be resolved.
        assert_qualified(
            "SELECT x, y FROM unknown_table",
            "SELECT x, y FROM unknown_table",
        );
    }
}