Skip to main content

sqlglot_rust/schema/
mod.rs

1//! Schema management system for schema-aware analysis and optimization.
2//!
3//! Provides a [`Schema`] trait and [`MappingSchema`] implementation analogous
4//! to Python sqlglot's `MappingSchema`. This is the foundation for type
5//! annotation, column qualification, projection pushdown, and lineage analysis.
6//!
7//! # Example
8//!
9//! ```rust
10//! use sqlglot_rust::schema::{MappingSchema, Schema};
11//! use sqlglot_rust::ast::DataType;
12//! use sqlglot_rust::Dialect;
13//!
14//! let mut schema = MappingSchema::new(Dialect::Ansi);
15//! schema.add_table(
16//!     &["catalog", "db", "users"],
17//!     vec![
18//!         ("id".to_string(), DataType::Int),
19//!         ("name".to_string(), DataType::Varchar(Some(255))),
20//!         ("email".to_string(), DataType::Text),
21//!     ],
22//! ).unwrap();
23//!
24//! assert_eq!(
25//!     schema.column_names(&["catalog", "db", "users"]).unwrap(),
26//!     vec!["id", "name", "email"],
27//! );
28//! assert_eq!(
29//!     schema.get_column_type(&["catalog", "db", "users"], "id").unwrap(),
30//!     DataType::Int,
31//! );
32//! assert!(schema.has_column(&["catalog", "db", "users"], "id"));
33//! ```
34
35use std::collections::HashMap;
36
37use crate::ast::DataType;
38use crate::dialects::Dialect;
39use crate::errors::SqlglotError;
40
/// Failure modes reported by [`Schema`] operations.
///
/// Variants carry the table path (and, for column lookups, the column name)
/// as a plain string for diagnostics. [`MappingSchema`] also reports
/// `TableNotFound` when a path has an unsupported number of parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// No table is registered under the requested path.
    TableNotFound(String),
    /// The table exists but has no column with the requested name.
    ColumnNotFound { table: String, column: String },
    /// A table is already registered under the requested path.
    DuplicateTable(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TableNotFound(path) => write!(f, "Table not found: {path}"),
            Self::DuplicateTable(path) => write!(f, "Table already exists: {path}"),
            Self::ColumnNotFound { table, column } => {
                write!(f, "Column '{column}' not found in table '{table}'")
            }
        }
    }
}

impl std::error::Error for SchemaError {}
65
66impl From<SchemaError> for SqlglotError {
67    fn from(e: SchemaError) -> Self {
68        SqlglotError::Internal(e.to_string())
69    }
70}
71
72/// Schema trait for schema-aware analysis and optimization.
73///
74/// Provides methods to query table and column metadata. Implementations
75/// can back this with in-memory mappings, database catalogs, etc.
76pub trait Schema {
77    /// Register a table with its column definitions.
78    ///
79    /// `table_path` is a slice of identifiers representing the fully qualified
80    /// table path: `[catalog, database, table]`, `[database, table]`, or `[table]`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`SchemaError::DuplicateTable`] if the table is already registered
85    /// (use `replace_table` to overwrite).
86    fn add_table(
87        &mut self,
88        table_path: &[&str],
89        columns: Vec<(String, DataType)>,
90    ) -> Result<(), SchemaError>;
91
92    /// Get the column names for a table, in definition order.
93    ///
94    /// # Errors
95    ///
96    /// Returns [`SchemaError::TableNotFound`] if the table is not registered.
97    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;
98
99    /// Get the data type of a specific column in a table.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`SchemaError::TableNotFound`] or [`SchemaError::ColumnNotFound`].
104    fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError>;
105
106    /// Check whether a column exists in the given table.
107    fn has_column(&self, table_path: &[&str], column: &str) -> bool;
108
109    /// Get the dialect used for identifier normalization.
110    fn dialect(&self) -> Dialect;
111
112    /// Get the return type of a registered UDF (user-defined function).
113    ///
114    /// Returns `None` by default. Implementations that support UDFs should
115    /// override this method.
116    fn get_udf_type(&self, _name: &str) -> Option<&DataType> {
117        None
118    }
119}
120
121/// Column metadata stored inside a [`MappingSchema`].
122#[derive(Debug, Clone, PartialEq)]
123struct ColumnInfo {
124    /// Original insertion order for stable column listing.
125    columns: Vec<(String, DataType)>,
126    /// Fast lookup by normalized column name → index into `columns`.
127    index: HashMap<String, usize>,
128}
129
130impl ColumnInfo {
131    fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
132        let index = columns
133            .iter()
134            .enumerate()
135            .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
136            .collect();
137        Self { columns, index }
138    }
139
140    fn column_names(&self) -> Vec<String> {
141        self.columns.iter().map(|(n, _)| n.clone()).collect()
142    }
143
144    fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
145        let key = normalize_identifier(column, dialect);
146        self.index.get(&key).map(|&i| &self.columns[i].1)
147    }
148
149    fn has_column(&self, column: &str, dialect: Dialect) -> bool {
150        let key = normalize_identifier(column, dialect);
151        self.index.contains_key(&key)
152    }
153}
154
155/// A schema backed by in-memory hash maps, supporting 3-level nesting:
156/// `catalog → database → table → column → type`.
157///
158/// Analogous to Python sqlglot's `MappingSchema`.
159///
160/// Identifiers are normalized according to the configured dialect:
161/// - Case-insensitive dialects (most): identifiers are lowercased for lookup.
162/// - Case-sensitive dialects (e.g. BigQuery, Hive): identifiers are kept as-is.
163/// - Quoted identifiers are always stored verbatim (not normalized).
164#[derive(Debug, Clone)]
165pub struct MappingSchema {
166    dialect: Dialect,
167    /// Nested map: catalog → database → table → ColumnInfo
168    tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
169    /// UDF return type mappings: function_name (normalized) → return type
170    udf_types: HashMap<String, DataType>,
171}
172
173impl MappingSchema {
174    /// Create a new empty schema for the given dialect.
175    #[must_use]
176    pub fn new(dialect: Dialect) -> Self {
177        Self {
178            dialect,
179            tables: HashMap::new(),
180            udf_types: HashMap::new(),
181        }
182    }
183
184    /// Replace a table if it already exists, or add it if it doesn't.
185    pub fn replace_table(
186        &mut self,
187        table_path: &[&str],
188        columns: Vec<(String, DataType)>,
189    ) -> Result<(), SchemaError> {
190        let (catalog, database, table) = self.resolve_path(table_path)?;
191        let info = ColumnInfo::new(columns, self.dialect);
192        self.tables
193            .entry(catalog)
194            .or_default()
195            .entry(database)
196            .or_default()
197            .insert(table, info);
198        Ok(())
199    }
200
201    /// Remove a table from the schema. Returns `true` if the table existed.
202    pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
203        let (catalog, database, table) = self.resolve_path(table_path)?;
204        let removed = self
205            .tables
206            .get_mut(&catalog)
207            .and_then(|dbs| dbs.get_mut(&database))
208            .map(|tbls| tbls.remove(&table).is_some())
209            .unwrap_or(false);
210        Ok(removed)
211    }
212
213    /// Register a UDF (user-defined function) with its return type.
214    pub fn add_udf(&mut self, name: &str, return_type: DataType) {
215        let key = normalize_identifier(name, self.dialect);
216        self.udf_types.insert(key, return_type);
217    }
218
219    /// Get the return type of a registered UDF.
220    #[must_use]
221    pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
222        let key = normalize_identifier(name, self.dialect);
223        self.udf_types.get(&key)
224    }
225
226    /// List all registered tables as `(catalog, database, table)` triples.
227    #[must_use]
228    pub fn table_names(&self) -> Vec<(String, String, String)> {
229        let mut result = Vec::new();
230        for (catalog, dbs) in &self.tables {
231            for (database, tbls) in dbs {
232                for table in tbls.keys() {
233                    result.push((catalog.clone(), database.clone(), table.clone()));
234                }
235            }
236        }
237        result
238    }
239
240    /// Find a table across all catalogs/databases when only a short path is given.
241    /// Returns the first match found (useful for unqualified table references).
242    fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
243        let (catalog, database, table) = match self.resolve_path(table_path) {
244            Ok(parts) => parts,
245            Err(_) => return None,
246        };
247
248        // Exact match first
249        if let Some(info) = self
250            .tables
251            .get(&catalog)
252            .and_then(|dbs| dbs.get(&database))
253            .and_then(|tbls| tbls.get(&table))
254        {
255            return Some(info);
256        }
257
258        // For single-name lookups, search all catalogs/databases
259        if table_path.len() == 1 {
260            let norm_name = normalize_identifier(table_path[0], self.dialect);
261            for dbs in self.tables.values() {
262                for tbls in dbs.values() {
263                    if let Some(info) = tbls.get(&norm_name) {
264                        return Some(info);
265                    }
266                }
267            }
268        }
269
270        // For 2-part lookups (db.table), search all catalogs
271        if table_path.len() == 2 {
272            let norm_db = normalize_identifier(table_path[0], self.dialect);
273            let norm_tbl = normalize_identifier(table_path[1], self.dialect);
274            for dbs in self.tables.values() {
275                if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
276                    return Some(info);
277                }
278            }
279        }
280
281        None
282    }
283
284    /// Resolve a table path into normalized (catalog, database, table) parts,
285    /// filling in defaults for missing levels.
286    fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
287        match table_path.len() {
288            1 => Ok((
289                String::new(),
290                String::new(),
291                normalize_identifier(table_path[0], self.dialect),
292            )),
293            2 => Ok((
294                String::new(),
295                normalize_identifier(table_path[0], self.dialect),
296                normalize_identifier(table_path[1], self.dialect),
297            )),
298            3 => Ok((
299                normalize_identifier(table_path[0], self.dialect),
300                normalize_identifier(table_path[1], self.dialect),
301                normalize_identifier(table_path[2], self.dialect),
302            )),
303            _ => Err(SchemaError::TableNotFound(table_path.join("."))),
304        }
305    }
306
307    fn format_table_path(table_path: &[&str]) -> String {
308        table_path.join(".")
309    }
310}
311
312impl Schema for MappingSchema {
313    fn add_table(
314        &mut self,
315        table_path: &[&str],
316        columns: Vec<(String, DataType)>,
317    ) -> Result<(), SchemaError> {
318        let (catalog, database, table) = self.resolve_path(table_path)?;
319        let entry = self
320            .tables
321            .entry(catalog)
322            .or_default()
323            .entry(database)
324            .or_default();
325
326        if entry.contains_key(&table) {
327            return Err(SchemaError::DuplicateTable(Self::format_table_path(
328                table_path,
329            )));
330        }
331
332        let info = ColumnInfo::new(columns, self.dialect);
333        entry.insert(table, info);
334        Ok(())
335    }
336
337    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
338        self.find_table(table_path)
339            .map(|info| info.column_names())
340            .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
341    }
342
343    fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError> {
344        let table_str = Self::format_table_path(table_path);
345        let info = self
346            .find_table(table_path)
347            .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
348
349        info.get_type(column, self.dialect)
350            .cloned()
351            .ok_or(SchemaError::ColumnNotFound {
352                table: table_str,
353                column: column.to_string(),
354            })
355    }
356
357    fn has_column(&self, table_path: &[&str], column: &str) -> bool {
358        self.find_table(table_path)
359            .is_some_and(|info| info.has_column(column, self.dialect))
360    }
361
362    fn dialect(&self) -> Dialect {
363        self.dialect
364    }
365
366    fn get_udf_type(&self, name: &str) -> Option<&DataType> {
367        let key = normalize_identifier(name, self.dialect);
368        self.udf_types.get(&key)
369    }
370}
371
372// ═══════════════════════════════════════════════════════════════════════
373// Identifier normalization
374// ═══════════════════════════════════════════════════════════════════════
375
376/// Normalize an identifier according to the given dialect's conventions.
377///
378/// Most SQL dialects treat unquoted identifiers as case-insensitive
379/// (typically by converting to lowercase internally). A few dialects
380/// (BigQuery, Hive, Spark, Databricks) are case-sensitive for table/column
381/// names.
382#[must_use]
383pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
384    if is_case_sensitive_dialect(dialect) {
385        name.to_string()
386    } else {
387        name.to_lowercase()
388    }
389}
390
391/// Returns `true` if the dialect treats unquoted identifiers as case-sensitive.
392#[must_use]
393pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
394    matches!(
395        dialect,
396        Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
397    )
398}
399
400// ═══════════════════════════════════════════════════════════════════════
401// Helper: build schema from nested maps
402// ═══════════════════════════════════════════════════════════════════════
403
404/// Build a [`MappingSchema`] from a nested map structure.
405///
406/// The input maps from table name → column name → data type, mirroring
407/// the common pattern of constructing schemas from DDL or metadata.
408///
409/// # Example
410///
411/// ```rust
412/// use std::collections::HashMap;
413/// use sqlglot_rust::schema::{ensure_schema, Schema};
414/// use sqlglot_rust::ast::DataType;
415/// use sqlglot_rust::Dialect;
416///
417/// let mut tables = HashMap::new();
418/// let mut columns = HashMap::new();
419/// columns.insert("id".to_string(), DataType::Int);
420/// columns.insert("name".to_string(), DataType::Varchar(Some(255)));
421/// tables.insert("users".to_string(), columns);
422///
423/// let schema = ensure_schema(tables, Dialect::Postgres);
424/// assert!(schema.has_column(&["users"], "id"));
425/// ```
426pub fn ensure_schema(
427    tables: HashMap<String, HashMap<String, DataType>>,
428    dialect: Dialect,
429) -> MappingSchema {
430    let mut schema = MappingSchema::new(dialect);
431    for (table_name, columns) in tables {
432        let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
433        // Use replace_table to avoid DuplicateTable errors
434        let _ = schema.replace_table(&[&table_name], col_vec);
435    }
436    schema
437}
438
439/// Type alias for the 3-level nested schema map:
440/// `catalog → database → table → column → type`.
441pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
442
443/// Build a [`MappingSchema`] from a 3-level nested map:
444/// `catalog → database → table → column → type`.
445pub fn ensure_schema_nested(catalog_map: CatalogMap, dialect: Dialect) -> MappingSchema {
446    let mut schema = MappingSchema::new(dialect);
447    for (catalog, databases) in catalog_map {
448        for (database, tables) in databases {
449            for (table, columns) in tables {
450                let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
451                let _ = schema.replace_table(&[&catalog, &database, &table], col_vec);
452            }
453        }
454    }
455    schema
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a schema for `dialect` holding a single table at `path`.
    fn schema_with(
        dialect: Dialect,
        path: &[&str],
        columns: Vec<(String, DataType)>,
    ) -> MappingSchema {
        let mut schema = MappingSchema::new(dialect);
        schema.add_table(path, columns).unwrap();
        schema
    }

    /// Shorthand for a `(name, type)` column pair.
    fn col(name: &str, ty: DataType) -> (String, DataType) {
        (name.to_string(), ty)
    }

    // ── Basic operations ────────────────────────────────────────────────

    #[test]
    fn test_add_and_query_table() {
        let schema = schema_with(
            Dialect::Ansi,
            &["users"],
            vec![
                col("id", DataType::Int),
                col("name", DataType::Varchar(Some(255))),
                col("email", DataType::Text),
            ],
        );
        let users: &[&str] = &["users"];

        assert_eq!(
            schema.column_names(users).unwrap(),
            vec!["id", "name", "email"]
        );
        assert_eq!(schema.get_column_type(users, "id").unwrap(), DataType::Int);
        assert_eq!(
            schema.get_column_type(users, "name").unwrap(),
            DataType::Varchar(Some(255))
        );
        for present in ["id", "email"] {
            assert!(schema.has_column(users, present));
        }
        assert!(!schema.has_column(users, "nonexistent"));
    }

    #[test]
    fn test_duplicate_table_error() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);
        let outcome = schema.add_table(&["t"], vec![col("b", DataType::Text)]);
        assert!(matches!(outcome, Err(SchemaError::DuplicateTable(_))));
    }

    #[test]
    fn test_replace_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);
        schema
            .replace_table(&["t"], vec![col("b", DataType::Text)])
            .unwrap();

        assert_eq!(schema.column_names(&["t"]).unwrap(), vec!["b"]);
        assert_eq!(schema.get_column_type(&["t"], "b").unwrap(), DataType::Text);
    }

    #[test]
    fn test_remove_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        assert_eq!(schema.remove_table(&["t"]), Ok(true));
        assert_eq!(schema.remove_table(&["t"]), Ok(false));
        assert!(schema.column_names(&["t"]).is_err());
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(matches!(
            schema.column_names(&["nonexistent"]),
            Err(SchemaError::TableNotFound(_))
        ));
    }

    #[test]
    fn test_column_not_found() {
        let schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);
        assert!(matches!(
            schema.get_column_type(&["t"], "z"),
            Err(SchemaError::ColumnNotFound { .. })
        ));
    }

    // ── Multi-level nesting ─────────────────────────────────────────────

    #[test]
    fn test_three_level_path() {
        let path = ["my_catalog", "my_db", "orders"];
        let schema = schema_with(
            Dialect::Ansi,
            &path,
            vec![
                col("order_id", DataType::BigInt),
                col(
                    "total",
                    DataType::Decimal {
                        precision: Some(10),
                        scale: Some(2),
                    },
                ),
            ],
        );

        assert_eq!(
            schema.column_names(&path).unwrap(),
            vec!["order_id", "total"]
        );
        assert!(schema.has_column(&path, "order_id"));
    }

    #[test]
    fn test_two_level_path() {
        let schema = schema_with(
            Dialect::Ansi,
            &["public", "users"],
            vec![col("id", DataType::Int)],
        );
        assert_eq!(
            schema.column_names(&["public", "users"]).unwrap(),
            vec!["id"]
        );
    }

    #[test]
    fn test_short_path_searches_all() {
        let schema = schema_with(
            Dialect::Ansi,
            &["catalog", "db", "orders"],
            vec![col("id", DataType::Int)],
        );

        // A bare table name is resolved by searching every catalog and database.
        assert!(schema.has_column(&["orders"], "id"));
        assert_eq!(schema.column_names(&["orders"]).unwrap(), vec!["id"]);

        // A `db.table` reference is resolved by searching every catalog.
        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    // ── Dialect-aware normalization ─────────────────────────────────────

    #[test]
    fn test_case_insensitive_dialect() {
        let schema = schema_with(Dialect::Postgres, &["Users"], vec![col("ID", DataType::Int)]);

        // Every spelling resolves because Postgres folds case.
        for (table, column) in [("users", "id"), ("USERS", "ID"), ("Users", "Id")] {
            assert!(schema.has_column(&[table], column));
        }
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
    }

    #[test]
    fn test_case_sensitive_dialect() {
        let schema = schema_with(Dialect::BigQuery, &["Users"], vec![col("ID", DataType::Int)]);

        // BigQuery keeps identifiers verbatim.
        assert!(schema.has_column(&["Users"], "ID"));
        assert!(!schema.has_column(&["users"], "ID"));
        assert!(!schema.has_column(&["Users"], "id"));
    }

    #[test]
    fn test_hive_case_sensitive() {
        let schema = schema_with(Dialect::Hive, &["MyTable"], vec![col("Col1", DataType::Text)]);

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    // ── UDF return types ────────────────────────────────────────────────

    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        assert_eq!(schema.get_udf_type("my_custom_fn"), Some(&DataType::Int));
        // ANSI lookups ignore case.
        assert_eq!(schema.get_udf_type("MY_CUSTOM_FN"), Some(&DataType::Int));
        assert_eq!(schema.get_udf_type("nonexistent"), None);
    }

    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    // ── ensure_schema helpers ───────────────────────────────────────────

    #[test]
    fn test_ensure_schema() {
        let columns = HashMap::from([
            ("id".to_string(), DataType::Int),
            ("name".to_string(), DataType::Text),
        ]);
        let tables = HashMap::from([("users".to_string(), columns)]);

        let schema = ensure_schema(tables, Dialect::Postgres);
        for column in ["id", "name"] {
            assert!(schema.has_column(&["users"], column));
        }
    }

    #[test]
    fn test_ensure_schema_nested() {
        let orders = HashMap::from([("order_id".to_string(), DataType::BigInt)]);
        let sales = HashMap::from([("orders".to_string(), orders)]);
        let warehouse = HashMap::from([("sales".to_string(), sales)]);
        let catalogs: CatalogMap = HashMap::from([("warehouse".to_string(), warehouse)]);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        // The bare table name resolves through the short-path search.
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    // ── table_names listing ─────────────────────────────────────────────

    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        for (table, column) in [("t1", "a"), ("t2", "b")] {
            schema
                .add_table(&["cat", "db", table], vec![col(column, DataType::Int)])
                .unwrap();
        }

        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        for expected in ["t1", "t2"] {
            assert!(names
                .iter()
                .any(|(c, d, t)| c == "cat" && d == "db" && t == expected));
        }
    }

    // ── Invalid path ────────────────────────────────────────────────────

    #[test]
    fn test_invalid_path_too_many_parts() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let outcome = schema.add_table(&["a", "b", "c", "d"], vec![col("x", DataType::Int)]);
        assert!(matches!(outcome, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    // ── Schema error display ────────────────────────────────────────────

    #[test]
    fn test_schema_error_display() {
        let cases = [
            (
                SchemaError::TableNotFound("users".to_string()),
                "Table not found: users",
            ),
            (
                SchemaError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "age".to_string(),
                },
                "Column 'age' not found in table 'users'",
            ),
            (
                SchemaError::DuplicateTable("users".to_string()),
                "Table already exists: users",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    // ── SchemaError → SqlglotError conversion ───────────────────────────

    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let converted = SqlglotError::from(SchemaError::TableNotFound("t".to_string()));
        assert!(matches!(converted, SqlglotError::Internal(_)));
    }

    // ── Multiple dialects ───────────────────────────────────────────────

    #[test]
    fn test_multiple_dialects_normalization() {
        // Postgres and MySQL fold case.
        for dialect in [Dialect::Postgres, Dialect::Mysql] {
            let schema = schema_with(dialect, &["T"], vec![col("C", DataType::Int)]);
            assert!(schema.has_column(&["t"], "c"));
        }

        // Spark does not.
        let spark = schema_with(Dialect::Spark, &["T"], vec![col("C", DataType::Int)]);
        assert!(!spark.has_column(&["t"], "c"));
        assert!(spark.has_column(&["T"], "C"));
    }

    // ── Complex data types ──────────────────────────────────────────────

    #[test]
    fn test_complex_data_types() {
        let tags = DataType::Array(Some(Box::new(DataType::Text)));
        let coords = DataType::Struct(vec![
            col("lat", DataType::Double),
            col("lng", DataType::Double),
        ]);
        let lookup = DataType::Map {
            key: Box::new(DataType::Text),
            value: Box::new(DataType::Int),
        };
        let schema = schema_with(
            Dialect::Ansi,
            &["complex_table"],
            vec![
                col("tags", tags.clone()),
                col("metadata", DataType::Json),
                col("coords", coords),
                col("lookup", lookup),
            ],
        );

        assert_eq!(
            schema.get_column_type(&["complex_table"], "tags").unwrap(),
            tags
        );
        assert_eq!(
            schema
                .get_column_type(&["complex_table"], "metadata")
                .unwrap(),
            DataType::Json
        );
    }

    // ── dialect() accessor ──────────────────────────────────────────────

    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}