1use std::borrow::Borrow;
2use std::collections::HashMap;
3use std::fmt;
4use std::hash::{Hash, Hasher};
5use std::ops::Deref;
6
/// Case-insensitive table name used as the key of the schema maps in this file.
///
/// The name is lowercased once, in [`TableKey::new`]; every other operation
/// (equality, hashing, display, borrowing) works on that normalized text.
#[derive(Debug, Clone, Eq)]
pub struct TableKey(String);

impl TableKey {
    /// Creates a key from `name`, normalizing it with [`str::to_lowercase`].
    #[inline]
    pub fn new(name: impl AsRef<str>) -> Self {
        TableKey(name.as_ref().to_lowercase())
    }

    /// Returns the normalized (lowercase) name as a string slice.
    #[inline]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Consumes the key and returns the normalized name.
    #[inline]
    pub fn into_inner(self) -> String {
        self.0
    }
}

impl PartialEq for TableKey {
    /// Two keys are equal when their normalized names are equal.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Hash for TableKey {
    /// Hashes only the normalized name, i.e. exactly what a `str` with the same
    /// contents would hash to. The `Borrow<str>` impl relies on this so that
    /// map lookups by `&str` land in the same bucket.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}

impl Deref for TableKey {
    type Target = str;

    /// Exposes the normalized name so `str` methods can be called on a key.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl AsRef<str> for TableKey {
    /// Borrows the normalized name.
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl Borrow<str> for TableKey {
    /// Borrows the normalized name, allowing `HashMap<TableKey, _>` lookups by
    /// `&str`. The borrowed string is compared as-is, so a `&str` probe must
    /// already be lowercase to match.
    fn borrow(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for TableKey {
    /// Writes the normalized name, honoring the caller's width, fill,
    /// alignment and precision flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the formatter's flags; `write!(f, "{}", ..)` would start
        // a fresh format spec and silently drop them.
        f.pad(&self.0)
    }
}

impl From<String> for TableKey {
    /// Normalizes an owned name into a key.
    fn from(s: String) -> Self {
        TableKey::new(s)
    }
}

impl From<&str> for TableKey {
    /// Normalizes a borrowed name into a key.
    fn from(s: &str) -> Self {
        TableKey::new(s)
    }
}

impl From<TableKey> for String {
    /// Unwraps the key into its normalized name.
    fn from(key: TableKey) -> Self {
        key.0
    }
}

impl From<&TableKey> for TableKey {
    /// Clones the key, so `&TableKey` can be passed where `impl Into<TableKey>`
    /// is expected.
    fn from(key: &TableKey) -> Self {
        key.clone()
    }
}

impl From<&String> for TableKey {
    /// Normalizes a borrowed `String` into a key.
    fn from(s: &String) -> Self {
        TableKey::new(s)
    }
}
99
100#[derive(Debug, Clone)]
102pub struct CombinedSchema {
103 pub table_schemas: HashMap<TableKey, (usize, vibesql_catalog::TableSchema)>,
107 pub total_columns: usize,
109}
110
111impl CombinedSchema {
112 pub fn from_table(table_name: String, schema: vibesql_catalog::TableSchema) -> Self {
116 let total_columns = schema.columns.len();
117 let mut table_schemas = HashMap::new();
118 table_schemas.insert(TableKey::new(table_name), (0, schema));
120 CombinedSchema { table_schemas, total_columns }
121 }
122
123 pub fn from_derived_table(
127 alias: String,
128 column_names: Vec<String>,
129 column_types: Vec<vibesql_types::DataType>,
130 ) -> Self {
131 let total_columns = column_names.len();
132
133 let columns: Vec<vibesql_catalog::ColumnSchema> = column_names
135 .into_iter()
136 .zip(column_types)
137 .map(|(name, data_type)| vibesql_catalog::ColumnSchema {
138 name,
139 data_type,
140 nullable: true, default_value: None, })
143 .collect();
144
145 let schema = vibesql_catalog::TableSchema::new(alias.clone(), columns);
146 let mut table_schemas = HashMap::new();
147 table_schemas.insert(TableKey::new(alias), (0, schema));
149 CombinedSchema { table_schemas, total_columns }
150 }
151
152 pub fn combine(
156 left: CombinedSchema,
157 right_table: impl Into<TableKey>,
158 right_schema: vibesql_catalog::TableSchema,
159 ) -> Self {
160 let mut table_schemas = left.table_schemas;
161 let left_total = left.total_columns;
162 let right_columns = right_schema.columns.len();
163 table_schemas.insert(right_table.into(), (left_total, right_schema));
165 CombinedSchema { table_schemas, total_columns: left_total + right_columns }
166 }
167
168 pub fn get_column_index(&self, table: Option<&str>, column: &str) -> Option<usize> {
171 if let Some(table_name) = table {
172 let key = TableKey::new(table_name);
175 if let Some((start_index, schema)) = self.table_schemas.get(&key) {
176 return schema.get_column_index(column).map(|idx| start_index + idx);
177 }
178 None
179 } else {
180 let mut best_match: Option<usize> = None;
185 for (start_index, schema) in self.table_schemas.values() {
186 if let Some(idx) = schema.get_column_index(column) {
187 let absolute_idx = start_index + idx;
188 match best_match {
189 None => best_match = Some(absolute_idx),
190 Some(current_best) if absolute_idx < current_best => {
191 best_match = Some(absolute_idx);
192 }
193 _ => {}
194 }
195 }
196 }
197 best_match
198 }
199 }
200
201 pub fn get_table(&self, table_name: &str) -> Option<&(usize, vibesql_catalog::TableSchema)> {
203 self.table_schemas.get(&TableKey::new(table_name))
204 }
205
206 pub fn contains_table(&self, table_name: &str) -> bool {
208 self.table_schemas.contains_key(&TableKey::new(table_name))
209 }
210
211 pub fn table_names(&self) -> Vec<String> {
213 self.table_schemas.keys().map(|k| k.to_string()).collect()
214 }
215
216 pub fn insert_table(
218 &mut self,
219 name: impl Into<TableKey>,
220 start_index: usize,
221 schema: vibesql_catalog::TableSchema,
222 ) {
223 self.table_schemas.insert(name.into(), (start_index, schema));
224 }
225}
226
227#[derive(Debug)]
232pub struct SchemaBuilder {
233 table_schemas: HashMap<TableKey, (usize, vibesql_catalog::TableSchema)>,
234 column_offset: usize,
235}
236
237impl SchemaBuilder {
238 pub fn new() -> Self {
240 SchemaBuilder { table_schemas: HashMap::new(), column_offset: 0 }
241 }
242
243 pub fn from_schema(schema: CombinedSchema) -> Self {
247 let column_offset = schema.total_columns;
248 SchemaBuilder { table_schemas: schema.table_schemas, column_offset }
250 }
251
252 pub fn add_table(&mut self, name: impl Into<TableKey>, schema: vibesql_catalog::TableSchema) -> &mut Self {
257 let num_columns = schema.columns.len();
258 self.table_schemas.insert(name.into(), (self.column_offset, schema));
260 self.column_offset += num_columns;
261 self
262 }
263
264 pub fn build(self) -> CombinedSchema {
268 CombinedSchema { table_schemas: self.table_schemas, total_columns: self.column_offset }
269 }
270}
271
272impl Default for SchemaBuilder {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
// Unit tests for case-insensitive table lookup through `CombinedSchema` and
// `SchemaBuilder`, plus the column offsets they assign.
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::ColumnSchema;
    use vibesql_types::DataType;

    /// Builds a table schema named `table_name` from `(column name, type)` pairs.
    fn table_schema_with_columns(table_name: &str, columns: Vec<(&str, DataType)>) -> vibesql_catalog::TableSchema {
        let cols: Vec<ColumnSchema> = columns
            .into_iter()
            .map(|(name, data_type)| ColumnSchema::new(name.to_string(), data_type, true))
            .collect();
        vibesql_catalog::TableSchema::new(table_name.to_string(), cols)
    }

    /// Builds a table schema with a single `Integer` column.
    fn table_schema_with_column(table_name: &str, column_name: &str) -> vibesql_catalog::TableSchema {
        table_schema_with_columns(table_name, vec![(column_name, DataType::Integer)])
    }

    /// A table registered in upper case is found under any casing of its name.
    #[test]
    fn test_from_table_uppercase_insertion_case_insensitive_lookup() {
        let schema = CombinedSchema::from_table(
            "ITEM".to_string(),
            table_schema_with_column("ITEM", "price"),
        );

        assert!(schema.get_column_index(Some("ITEM"), "price").is_some(), "ITEM should find price");
        assert!(schema.get_column_index(Some("item"), "price").is_some(), "item should find price");
        assert!(schema.get_column_index(Some("Item"), "price").is_some(), "Item should find price");
        assert!(schema.get_column_index(Some("iTEM"), "price").is_some(), "iTEM should find price");
    }

    /// A table registered in lower case is found under any casing of its name.
    #[test]
    fn test_from_table_lowercase_insertion_case_insensitive_lookup() {
        let schema = CombinedSchema::from_table(
            "item".to_string(),
            table_schema_with_column("item", "price"),
        );

        assert!(schema.get_column_index(Some("ITEM"), "price").is_some());
        assert!(schema.get_column_index(Some("item"), "price").is_some());
        assert!(schema.get_column_index(Some("Item"), "price").is_some());
    }

    /// A table registered in mixed case is found under any casing of its name.
    #[test]
    fn test_from_table_mixedcase_insertion_case_insensitive_lookup() {
        let schema = CombinedSchema::from_table(
            "MyTable".to_string(),
            table_schema_with_column("MyTable", "id"),
        );

        assert!(schema.get_column_index(Some("MYTABLE"), "id").is_some());
        assert!(schema.get_column_index(Some("mytable"), "id").is_some());
        assert!(schema.get_column_index(Some("MyTable"), "id").is_some());
        assert!(schema.get_column_index(Some("myTable"), "id").is_some());
    }

    /// A derived table's alias is matched case-insensitively.
    #[test]
    fn test_from_derived_table_case_insensitive_alias() {
        let schema = CombinedSchema::from_derived_table(
            "SUBQ".to_string(),
            vec!["col1".to_string(), "col2".to_string()],
            vec![DataType::Integer, DataType::Varchar { max_length: None }],
        );

        assert!(schema.get_column_index(Some("SUBQ"), "col1").is_some());
        assert!(schema.get_column_index(Some("subq"), "col1").is_some());
        assert!(schema.get_column_index(Some("Subq"), "col1").is_some());
    }

    /// After `combine`, both tables are matched case-insensitively and the right
    /// table's columns are numbered after the left table's.
    #[test]
    fn test_combine_case_insensitive_both_tables() {
        let left = CombinedSchema::from_table(
            "ORDERS".to_string(),
            table_schema_with_columns("ORDERS", vec![("order_id", DataType::Integer), ("customer_id", DataType::Integer)]),
        );

        let combined = CombinedSchema::combine(
            left,
            "Items".to_string(),
            table_schema_with_columns("Items", vec![("item_id", DataType::Integer), ("price", DataType::DoublePrecision)]),
        );

        // Left table under several casings.
        assert!(combined.get_column_index(Some("orders"), "order_id").is_some());
        assert!(combined.get_column_index(Some("ORDERS"), "order_id").is_some());
        assert!(combined.get_column_index(Some("Orders"), "customer_id").is_some());

        // Right table under several casings.
        assert!(combined.get_column_index(Some("items"), "item_id").is_some());
        assert!(combined.get_column_index(Some("ITEMS"), "item_id").is_some());
        assert!(combined.get_column_index(Some("Items"), "price").is_some());

        // Combined indices: left columns first, then right columns.
        assert_eq!(combined.get_column_index(Some("orders"), "order_id"), Some(0));
        assert_eq!(combined.get_column_index(Some("orders"), "customer_id"), Some(1));
        assert_eq!(combined.get_column_index(Some("items"), "item_id"), Some(2));
        assert_eq!(combined.get_column_index(Some("items"), "price"), Some(3));
    }

    /// Chained `combine` calls keep every table reachable under both casings.
    #[test]
    fn test_combine_multiple_joins_case_insensitive() {
        let orders = CombinedSchema::from_table(
            "O".to_string(), table_schema_with_column("O", "order_id"),
        );

        let with_customers = CombinedSchema::combine(
            orders,
            "C".to_string(),
            table_schema_with_column("C", "customer_id"),
        );

        let with_items = CombinedSchema::combine(
            with_customers,
            "I".to_string(),
            table_schema_with_column("I", "item_id"),
        );

        assert!(with_items.get_column_index(Some("o"), "order_id").is_some());
        assert!(with_items.get_column_index(Some("O"), "order_id").is_some());
        assert!(with_items.get_column_index(Some("c"), "customer_id").is_some());
        assert!(with_items.get_column_index(Some("C"), "customer_id").is_some());
        assert!(with_items.get_column_index(Some("i"), "item_id").is_some());
        assert!(with_items.get_column_index(Some("I"), "item_id").is_some());
    }

    /// Unqualified lookups find existing columns and report `None` for missing ones.
    #[test]
    fn test_unqualified_column_lookup_no_ambiguity() {
        let schema = CombinedSchema::from_table(
            "USERS".to_string(),
            table_schema_with_columns("USERS", vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: None })]),
        );

        assert!(schema.get_column_index(None, "id").is_some());
        assert!(schema.get_column_index(None, "name").is_some());
        assert!(schema.get_column_index(None, "missing").is_none());
    }

    /// A column declared as `UserName` resolves under its exact spelling and
    /// under all-lowercase / all-uppercase spellings (column matching itself is
    /// done by `TableSchema::get_column_index`).
    #[test]
    fn test_column_case_sensitive_with_fallback() {
        let schema = CombinedSchema::from_table(
            "users".to_string(),
            table_schema_with_column("users", "UserName"),
        );

        assert!(schema.get_column_index(Some("users"), "UserName").is_some());
        assert!(schema.get_column_index(Some("users"), "username").is_some());
        assert!(schema.get_column_index(Some("users"), "USERNAME").is_some());
    }

    /// Regression test named for issue 4111 (TPC-DS Q6): alias `J` over table
    /// `item`, with upper-case column references to lower-case columns.
    #[test]
    fn test_tpcds_q6_case_insensitive_column_lookup_issue_4111() {
        let schema = CombinedSchema::from_table(
            "J".to_string(), table_schema_with_columns(
                "item",
                vec![
                    ("i_item_sk", DataType::Integer),
                    ("i_current_price", DataType::DoublePrecision), ("i_category", DataType::Varchar { max_length: None }),
                ],
            ),
        );

        assert!(
            schema.get_column_index(Some("J"), "I_CURRENT_PRICE").is_some(),
            "J.I_CURRENT_PRICE should find i_current_price via case-insensitive lookup"
        );
        assert!(
            schema.get_column_index(Some("J"), "I_CATEGORY").is_some(),
            "J.I_CATEGORY should find i_category via case-insensitive lookup"
        );
        assert!(
            schema.get_column_index(Some("j"), "I_CURRENT_PRICE").is_some(),
            "j.I_CURRENT_PRICE should find i_current_price"
        );
        assert!(
            schema.get_column_index(Some("J"), "i_current_price").is_some(),
            "J.i_current_price should find via exact match"
        );
    }

    /// When columns differ only by case, each spelling resolves to its own
    /// exactly-matching column.
    #[test]
    fn test_column_distinct_cases_exact_match() {
        let cols: Vec<vibesql_catalog::ColumnSchema> = vec![
            vibesql_catalog::ColumnSchema::new("value".to_string(), DataType::Integer, true),
            vibesql_catalog::ColumnSchema::new("VALUE".to_string(), DataType::Integer, true),
            vibesql_catalog::ColumnSchema::new("Value".to_string(), DataType::Integer, true),
        ];
        let table_schema = vibesql_catalog::TableSchema::new("data".to_string(), cols);
        let schema = CombinedSchema::from_table("data".to_string(), table_schema);

        assert_eq!(schema.get_column_index(Some("data"), "value"), Some(0));
        assert_eq!(schema.get_column_index(Some("data"), "VALUE"), Some(1));
        assert_eq!(schema.get_column_index(Some("data"), "Value"), Some(2));
    }

    /// Tables added through `SchemaBuilder::add_table` are matched case-insensitively.
    #[test]
    fn test_schema_builder_add_table_case_insensitive() {
        let mut builder = SchemaBuilder::new();

        builder.add_table(
            "ORDERS".to_string(),
            table_schema_with_column("ORDERS", "order_id"),
        );
        builder.add_table(
            "Items".to_string(),
            table_schema_with_column("Items", "item_id"),
        );

        let schema = builder.build();

        assert!(schema.get_column_index(Some("orders"), "order_id").is_some());
        assert!(schema.get_column_index(Some("ORDERS"), "order_id").is_some());
        assert!(schema.get_column_index(Some("items"), "item_id").is_some());
        assert!(schema.get_column_index(Some("ITEMS"), "item_id").is_some());
    }

    /// Round-tripping a schema through `SchemaBuilder::from_schema` keeps its
    /// tables case-insensitively reachable, alongside newly added ones.
    #[test]
    fn test_schema_builder_from_schema_preserves_case_insensitivity() {
        let initial = CombinedSchema::from_table(
            "PRODUCTS".to_string(),
            table_schema_with_columns("PRODUCTS", vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: None })]),
        );

        // Sanity check before handing the schema to the builder.
        assert!(initial.get_column_index(Some("products"), "id").is_some());

        let mut builder = SchemaBuilder::from_schema(initial);
        builder.add_table(
            "Categories".to_string(),
            table_schema_with_column("Categories", "cat_id"),
        );

        let final_schema = builder.build();

        // Pre-existing table.
        assert!(final_schema.get_column_index(Some("products"), "id").is_some());
        assert!(final_schema.get_column_index(Some("PRODUCTS"), "id").is_some());
        assert!(final_schema.get_column_index(Some("Products"), "name").is_some());

        // Table added through the builder.
        assert!(final_schema.get_column_index(Some("categories"), "cat_id").is_some());
        assert!(final_schema.get_column_index(Some("CATEGORIES"), "cat_id").is_some());
    }

    /// A builder seeded with a two-table schema appends a third table after the
    /// existing columns.
    #[test]
    fn test_schema_builder_from_schema_multiple_tables() {
        let orders = CombinedSchema::from_table(
            "Orders".to_string(),
            table_schema_with_column("Orders", "order_id"),
        );
        let combined = CombinedSchema::combine(
            orders,
            "Items".to_string(),
            table_schema_with_column("Items", "item_id"),
        );

        let mut builder = SchemaBuilder::from_schema(combined);
        builder.add_table(
            "CUSTOMERS".to_string(),
            table_schema_with_column("CUSTOMERS", "cust_id"),
        );

        let final_schema = builder.build();

        assert!(final_schema.get_column_index(Some("orders"), "order_id").is_some());
        assert!(final_schema.get_column_index(Some("ORDERS"), "order_id").is_some());
        assert!(final_schema.get_column_index(Some("items"), "item_id").is_some());
        assert!(final_schema.get_column_index(Some("ITEMS"), "item_id").is_some());
        assert!(final_schema.get_column_index(Some("customers"), "cust_id").is_some());
        assert!(final_schema.get_column_index(Some("CUSTOMERS"), "cust_id").is_some());

        // Offsets follow insertion order: one column per table.
        assert_eq!(final_schema.get_column_index(Some("orders"), "order_id"), Some(0));
        assert_eq!(final_schema.get_column_index(Some("items"), "item_id"), Some(1));
        assert_eq!(final_schema.get_column_index(Some("customers"), "cust_id"), Some(2));
    }

    /// Regression test named for issue 3633: an alias registered as `J` is found
    /// under both `J` and `j`.
    #[test]
    fn test_issue_3633_correlated_subquery_alias_case() {
        let schema = CombinedSchema::from_table(
            "J".to_string(), table_schema_with_columns("items", vec![("price", DataType::DoublePrecision), ("quantity", DataType::Integer)]),
        );

        assert!(schema.get_column_index(Some("J"), "price").is_some(),
            "Uppercase J should find price (parser case)");
        assert!(schema.get_column_index(Some("j"), "price").is_some(),
            "Lowercase j should find price (normalized case)");
    }

    /// Regression test named for issue 3633: aliases `O` and `I` over a join
    /// resolve to the expected combined indices under either casing.
    #[test]
    fn test_issue_3633_multi_table_join_with_aliases() {
        let orders = CombinedSchema::from_table(
            "O".to_string(),
            table_schema_with_columns("orders", vec![("id", DataType::Integer), ("date", DataType::Date)]),
        );

        let combined = CombinedSchema::combine(
            orders,
            "I".to_string(),
            table_schema_with_columns("items", vec![("order_id", DataType::Integer), ("amount", DataType::DoublePrecision)]),
        );

        assert_eq!(combined.get_column_index(Some("O"), "id"), Some(0));
        assert_eq!(combined.get_column_index(Some("o"), "id"), Some(0));
        assert_eq!(combined.get_column_index(Some("O"), "date"), Some(1));
        assert_eq!(combined.get_column_index(Some("I"), "order_id"), Some(2));
        assert_eq!(combined.get_column_index(Some("i"), "order_id"), Some(2));
        assert_eq!(combined.get_column_index(Some("I"), "amount"), Some(3));
    }

    /// A qualified lookup against an unregistered table yields `None`.
    #[test]
    fn test_nonexistent_table_returns_none() {
        let schema = CombinedSchema::from_table(
            "users".to_string(),
            table_schema_with_column("users", "id"),
        );

        assert!(schema.get_column_index(Some("nonexistent"), "id").is_none());
        assert!(schema.get_column_index(Some("NONEXISTENT"), "id").is_none());
    }

    /// A qualified lookup for a column the table lacks yields `None`.
    #[test]
    fn test_nonexistent_column_returns_none() {
        let schema = CombinedSchema::from_table(
            "users".to_string(),
            table_schema_with_column("users", "id"),
        );

        assert!(schema.get_column_index(Some("users"), "nonexistent").is_none());
        assert!(schema.get_column_index(Some("USERS"), "nonexistent").is_none());
    }

    /// An empty table name is accepted as a key and can be looked up.
    #[test]
    fn test_empty_table_name() {
        let schema = CombinedSchema::from_table(
            "".to_string(),
            table_schema_with_column("", "id"),
        );

        assert!(schema.get_column_index(Some(""), "id").is_some());
    }

    /// `total_columns` of a built schema is the sum of the added tables' column counts.
    #[test]
    fn test_total_columns_tracking() {
        let mut builder = SchemaBuilder::new();
        builder.add_table(
            "t1".to_string(),
            table_schema_with_columns("t1", vec![("a", DataType::Integer), ("b", DataType::Integer)]),
        );
        builder.add_table(
            "t2".to_string(),
            table_schema_with_columns("t2", vec![("c", DataType::Integer)]),
        );

        let schema = builder.build();
        assert_eq!(schema.total_columns, 3);
    }
}