1use std::collections::HashMap;
36
37use crate::ast::DataType;
38use crate::dialects::Dialect;
39use crate::errors::SqlglotError;
40
/// Failures reported by schema registration and lookup operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SchemaError {
    /// No table could be resolved for the given path. Also produced when a
    /// path has an unsupported number of parts. The payload is the path text.
    TableNotFound(String),
    /// The table was resolved but has no column with the requested name.
    ColumnNotFound {
        /// Path text of the table that was searched.
        table: String,
        /// Column name as supplied by the caller.
        column: String,
    },
    /// A table is already registered at the given path. The payload is the
    /// path text.
    DuplicateTable(String),
}
51
52impl std::fmt::Display for SchemaError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 SchemaError::TableNotFound(t) => write!(f, "Table not found: {t}"),
56 SchemaError::ColumnNotFound { table, column } => {
57 write!(f, "Column '{column}' not found in table '{table}'")
58 }
59 SchemaError::DuplicateTable(t) => write!(f, "Table already exists: {t}"),
60 }
61 }
62}
63
64impl std::error::Error for SchemaError {}
65
66impl From<SchemaError> for SqlglotError {
67 fn from(e: SchemaError) -> Self {
68 SqlglotError::Internal(e.to_string())
69 }
70}
71
72pub trait Schema {
77 fn add_table(
87 &mut self,
88 table_path: &[&str],
89 columns: Vec<(String, DataType)>,
90 ) -> Result<(), SchemaError>;
91
92 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;
98
99 fn get_column_type(
105 &self,
106 table_path: &[&str],
107 column: &str,
108 ) -> Result<DataType, SchemaError>;
109
110 fn has_column(&self, table_path: &[&str], column: &str) -> bool;
112
113 fn dialect(&self) -> Dialect;
115}
116
117#[derive(Debug, Clone, PartialEq)]
119struct ColumnInfo {
120 columns: Vec<(String, DataType)>,
122 index: HashMap<String, usize>,
124}
125
126impl ColumnInfo {
127 fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
128 let index = columns
129 .iter()
130 .enumerate()
131 .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
132 .collect();
133 Self { columns, index }
134 }
135
136 fn column_names(&self) -> Vec<String> {
137 self.columns.iter().map(|(n, _)| n.clone()).collect()
138 }
139
140 fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
141 let key = normalize_identifier(column, dialect);
142 self.index.get(&key).map(|&i| &self.columns[i].1)
143 }
144
145 fn has_column(&self, column: &str, dialect: Dialect) -> bool {
146 let key = normalize_identifier(column, dialect);
147 self.index.contains_key(&key)
148 }
149}
150
151#[derive(Debug, Clone)]
161pub struct MappingSchema {
162 dialect: Dialect,
163 tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
165 udf_types: HashMap<String, DataType>,
167}
168
169impl MappingSchema {
170 #[must_use]
172 pub fn new(dialect: Dialect) -> Self {
173 Self {
174 dialect,
175 tables: HashMap::new(),
176 udf_types: HashMap::new(),
177 }
178 }
179
180 pub fn replace_table(
182 &mut self,
183 table_path: &[&str],
184 columns: Vec<(String, DataType)>,
185 ) -> Result<(), SchemaError> {
186 let (catalog, database, table) = self.resolve_path(table_path)?;
187 let info = ColumnInfo::new(columns, self.dialect);
188 self.tables
189 .entry(catalog)
190 .or_default()
191 .entry(database)
192 .or_default()
193 .insert(table, info);
194 Ok(())
195 }
196
197 pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
199 let (catalog, database, table) = self.resolve_path(table_path)?;
200 let removed = self
201 .tables
202 .get_mut(&catalog)
203 .and_then(|dbs| dbs.get_mut(&database))
204 .map(|tbls| tbls.remove(&table).is_some())
205 .unwrap_or(false);
206 Ok(removed)
207 }
208
209 pub fn add_udf(&mut self, name: &str, return_type: DataType) {
211 let key = normalize_identifier(name, self.dialect);
212 self.udf_types.insert(key, return_type);
213 }
214
215 #[must_use]
217 pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
218 let key = normalize_identifier(name, self.dialect);
219 self.udf_types.get(&key)
220 }
221
222 #[must_use]
224 pub fn table_names(&self) -> Vec<(String, String, String)> {
225 let mut result = Vec::new();
226 for (catalog, dbs) in &self.tables {
227 for (database, tbls) in dbs {
228 for table in tbls.keys() {
229 result.push((catalog.clone(), database.clone(), table.clone()));
230 }
231 }
232 }
233 result
234 }
235
236 fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
239 let (catalog, database, table) = match self.resolve_path(table_path) {
240 Ok(parts) => parts,
241 Err(_) => return None,
242 };
243
244 if let Some(info) = self
246 .tables
247 .get(&catalog)
248 .and_then(|dbs| dbs.get(&database))
249 .and_then(|tbls| tbls.get(&table))
250 {
251 return Some(info);
252 }
253
254 if table_path.len() == 1 {
256 let norm_name = normalize_identifier(table_path[0], self.dialect);
257 for dbs in self.tables.values() {
258 for tbls in dbs.values() {
259 if let Some(info) = tbls.get(&norm_name) {
260 return Some(info);
261 }
262 }
263 }
264 }
265
266 if table_path.len() == 2 {
268 let norm_db = normalize_identifier(table_path[0], self.dialect);
269 let norm_tbl = normalize_identifier(table_path[1], self.dialect);
270 for dbs in self.tables.values() {
271 if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
272 return Some(info);
273 }
274 }
275 }
276
277 None
278 }
279
280 fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
283 match table_path.len() {
284 1 => Ok((
285 String::new(),
286 String::new(),
287 normalize_identifier(table_path[0], self.dialect),
288 )),
289 2 => Ok((
290 String::new(),
291 normalize_identifier(table_path[0], self.dialect),
292 normalize_identifier(table_path[1], self.dialect),
293 )),
294 3 => Ok((
295 normalize_identifier(table_path[0], self.dialect),
296 normalize_identifier(table_path[1], self.dialect),
297 normalize_identifier(table_path[2], self.dialect),
298 )),
299 _ => Err(SchemaError::TableNotFound(table_path.join("."))),
300 }
301 }
302
303 fn format_table_path(table_path: &[&str]) -> String {
304 table_path.join(".")
305 }
306}
307
308impl Schema for MappingSchema {
309 fn add_table(
310 &mut self,
311 table_path: &[&str],
312 columns: Vec<(String, DataType)>,
313 ) -> Result<(), SchemaError> {
314 let (catalog, database, table) = self.resolve_path(table_path)?;
315 let entry = self
316 .tables
317 .entry(catalog)
318 .or_default()
319 .entry(database)
320 .or_default();
321
322 if entry.contains_key(&table) {
323 return Err(SchemaError::DuplicateTable(
324 Self::format_table_path(table_path),
325 ));
326 }
327
328 let info = ColumnInfo::new(columns, self.dialect);
329 entry.insert(table, info);
330 Ok(())
331 }
332
333 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
334 self.find_table(table_path)
335 .map(|info| info.column_names())
336 .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
337 }
338
339 fn get_column_type(
340 &self,
341 table_path: &[&str],
342 column: &str,
343 ) -> Result<DataType, SchemaError> {
344 let table_str = Self::format_table_path(table_path);
345 let info = self
346 .find_table(table_path)
347 .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
348
349 info.get_type(column, self.dialect)
350 .cloned()
351 .ok_or(SchemaError::ColumnNotFound {
352 table: table_str,
353 column: column.to_string(),
354 })
355 }
356
357 fn has_column(&self, table_path: &[&str], column: &str) -> bool {
358 self.find_table(table_path)
359 .is_some_and(|info| info.has_column(column, self.dialect))
360 }
361
362 fn dialect(&self) -> Dialect {
363 self.dialect
364 }
365}
366
367#[must_use]
378pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
379 if is_case_sensitive_dialect(dialect) {
380 name.to_string()
381 } else {
382 name.to_lowercase()
383 }
384}
385
386#[must_use]
388pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
389 matches!(
390 dialect,
391 Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
392 )
393}
394
395pub fn ensure_schema(
422 tables: HashMap<String, HashMap<String, DataType>>,
423 dialect: Dialect,
424) -> MappingSchema {
425 let mut schema = MappingSchema::new(dialect);
426 for (table_name, columns) in tables {
427 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
428 let _ = schema.replace_table(&[&table_name], col_vec);
430 }
431 schema
432}
433
434pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
437
438pub fn ensure_schema_nested(
441 catalog_map: CatalogMap,
442 dialect: Dialect,
443) -> MappingSchema {
444 let mut schema = MappingSchema::new(dialect);
445 for (catalog, databases) in catalog_map {
446 for (database, tables) in databases {
447 for (table, columns) in tables {
448 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
449 let _ = schema.replace_table(
450 &[&catalog, &database, &table],
451 col_vec,
452 );
453 }
454 }
455 }
456 schema
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(name, type)` column entry.
    fn col(name: &str, data_type: DataType) -> (String, DataType) {
        (name.to_string(), data_type)
    }

    /// Creates a schema for `dialect` that contains a single table at `path`.
    fn schema_with(
        dialect: Dialect,
        path: &[&str],
        columns: Vec<(String, DataType)>,
    ) -> MappingSchema {
        let mut schema = MappingSchema::new(dialect);
        schema.add_table(path, columns).unwrap();
        schema
    }

    // --- Basic registration and lookup ---

    #[test]
    fn test_add_and_query_table() {
        let users = ["users"];
        let schema = schema_with(
            Dialect::Ansi,
            &users,
            vec![
                col("id", DataType::Int),
                col("name", DataType::Varchar(Some(255))),
                col("email", DataType::Text),
            ],
        );

        assert_eq!(
            schema.column_names(&users).unwrap(),
            vec!["id", "name", "email"]
        );
        assert_eq!(schema.get_column_type(&users, "id").unwrap(), DataType::Int);
        assert_eq!(
            schema.get_column_type(&users, "name").unwrap(),
            DataType::Varchar(Some(255))
        );
        assert!(schema.has_column(&users, "id"));
        assert!(schema.has_column(&users, "email"));
        assert!(!schema.has_column(&users, "nonexistent"));
    }

    #[test]
    fn test_duplicate_table_error() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let outcome = schema.add_table(&["t"], vec![col("b", DataType::Text)]);
        assert!(matches!(outcome, Err(SchemaError::DuplicateTable(_))));
    }

    #[test]
    fn test_replace_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        schema
            .replace_table(&["t"], vec![col("b", DataType::Text)])
            .unwrap();

        assert_eq!(schema.column_names(&["t"]).unwrap(), vec!["b"]);
        assert_eq!(
            schema.get_column_type(&["t"], "b").unwrap(),
            DataType::Text
        );
    }

    #[test]
    fn test_remove_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let first_removal = schema.remove_table(&["t"]).unwrap();
        assert!(first_removal);
        let second_removal = schema.remove_table(&["t"]).unwrap();
        assert!(!second_removal);
        assert!(schema.column_names(&["t"]).is_err());
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        let outcome = schema.column_names(&["nonexistent"]);
        assert!(matches!(outcome, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let outcome = schema.get_column_type(&["t"], "z");
        assert!(matches!(outcome, Err(SchemaError::ColumnNotFound { .. })));
    }

    // --- Multi-part paths ---

    #[test]
    fn test_three_level_path() {
        let path = ["my_catalog", "my_db", "orders"];
        let total_type = DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        };
        let schema = schema_with(
            Dialect::Ansi,
            &path,
            vec![col("order_id", DataType::BigInt), col("total", total_type)],
        );

        assert_eq!(
            schema.column_names(&path).unwrap(),
            vec!["order_id", "total"]
        );
        assert!(schema.has_column(&path, "order_id"));
    }

    #[test]
    fn test_two_level_path() {
        let path = ["public", "users"];
        let schema = schema_with(Dialect::Ansi, &path, vec![col("id", DataType::Int)]);

        assert_eq!(schema.column_names(&path).unwrap(), vec!["id"]);
    }

    #[test]
    fn test_short_path_searches_all() {
        let schema = schema_with(
            Dialect::Ansi,
            &["catalog", "db", "orders"],
            vec![col("id", DataType::Int)],
        );

        // A bare table name is found across all catalogs and databases.
        assert!(schema.has_column(&["orders"], "id"));
        assert_eq!(schema.column_names(&["orders"]).unwrap(), vec!["id"]);

        // A database-qualified name is found across all catalogs.
        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    // --- Identifier normalization ---

    #[test]
    fn test_case_insensitive_dialect() {
        let schema = schema_with(
            Dialect::Postgres,
            &["Users"],
            vec![col("ID", DataType::Int)],
        );

        for (table, column) in [("users", "id"), ("USERS", "ID"), ("Users", "Id")] {
            assert!(schema.has_column(&[table], column));
        }
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
    }

    #[test]
    fn test_case_sensitive_dialect() {
        let schema = schema_with(
            Dialect::BigQuery,
            &["Users"],
            vec![col("ID", DataType::Int)],
        );

        assert!(schema.has_column(&["Users"], "ID"));
        assert!(!schema.has_column(&["users"], "ID"));
        assert!(!schema.has_column(&["Users"], "id"));
    }

    #[test]
    fn test_hive_case_sensitive() {
        let schema = schema_with(
            Dialect::Hive,
            &["MyTable"],
            vec![col("Col1", DataType::Text)],
        );

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    // --- User-defined functions ---

    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        let expected = DataType::Int;
        assert_eq!(schema.get_udf_type("my_custom_fn").unwrap(), &expected);
        assert_eq!(schema.get_udf_type("MY_CUSTOM_FN").unwrap(), &expected);
        assert!(schema.get_udf_type("nonexistent").is_none());
    }

    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    // --- Construction from maps ---

    #[test]
    fn test_ensure_schema() {
        let columns = HashMap::from([
            ("id".to_string(), DataType::Int),
            ("name".to_string(), DataType::Text),
        ]);
        let tables = HashMap::from([("users".to_string(), columns)]);

        let schema = ensure_schema(tables, Dialect::Postgres);
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "name"));
    }

    #[test]
    fn test_ensure_schema_nested() {
        let columns = HashMap::from([("order_id".to_string(), DataType::BigInt)]);
        let tables = HashMap::from([("orders".to_string(), columns)]);
        let databases = HashMap::from([("sales".to_string(), tables)]);
        let catalogs = HashMap::from([("warehouse".to_string(), databases)]);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    // --- Listing ---

    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        for (table, column) in [("t1", "a"), ("t2", "b")] {
            schema
                .add_table(&["cat", "db", table], vec![col(column, DataType::Int)])
                .unwrap();
        }

        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        for expected in ["t1", "t2"] {
            assert!(names
                .iter()
                .any(|(c, d, t)| c == "cat" && d == "db" && t == expected));
        }
    }

    // --- Error paths ---

    #[test]
    fn test_invalid_path_too_many_parts() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let outcome = schema.add_table(&["a", "b", "c", "d"], vec![col("x", DataType::Int)]);
        assert!(matches!(outcome, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    #[test]
    fn test_schema_error_display() {
        let cases = [
            (
                SchemaError::TableNotFound("users".to_string()),
                "Table not found: users",
            ),
            (
                SchemaError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "age".to_string(),
                },
                "Column 'age' not found in table 'users'",
            ),
            (
                SchemaError::DuplicateTable("users".to_string()),
                "Table already exists: users",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let converted = SqlglotError::from(SchemaError::TableNotFound("t".to_string()));
        assert!(matches!(converted, SqlglotError::Internal(_)));
    }

    #[test]
    fn test_multiple_dialects_normalization() {
        // Postgres and MySQL fold identifiers to lowercase for lookup.
        for dialect in [Dialect::Postgres, Dialect::Mysql] {
            let folded = schema_with(dialect, &["T"], vec![col("C", DataType::Int)]);
            assert!(folded.has_column(&["t"], "c"));
        }

        // Spark keeps identifiers exactly as written.
        let exact = schema_with(Dialect::Spark, &["T"], vec![col("C", DataType::Int)]);
        assert!(!exact.has_column(&["t"], "c"));
        assert!(exact.has_column(&["T"], "C"));
    }

    // --- Data types ---

    #[test]
    fn test_complex_data_types() {
        let table = ["complex_table"];
        let schema = schema_with(
            Dialect::Ansi,
            &table,
            vec![
                col("tags", DataType::Array(Some(Box::new(DataType::Text)))),
                col("metadata", DataType::Json),
                col(
                    "coords",
                    DataType::Struct(vec![
                        ("lat".to_string(), DataType::Double),
                        ("lng".to_string(), DataType::Double),
                    ]),
                ),
                col(
                    "lookup",
                    DataType::Map {
                        key: Box::new(DataType::Text),
                        value: Box::new(DataType::Int),
                    },
                ),
            ],
        );

        assert_eq!(
            schema.get_column_type(&table, "tags").unwrap(),
            DataType::Array(Some(Box::new(DataType::Text)))
        );
        assert_eq!(
            schema.get_column_type(&table, "metadata").unwrap(),
            DataType::Json
        );
    }

    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}