Skip to main content

sqlglot_rust/schema/
mod.rs

//! Schema management system for schema-aware analysis and optimization.
//!
//! Provides a [`Schema`] trait and [`MappingSchema`] implementation analogous
//! to Python sqlglot's `MappingSchema`. This is the foundation for type
//! annotation, column qualification, projection pushdown, and lineage analysis.
//!
//! # Example
//!
//! ```rust
//! use sqlglot_rust::schema::{MappingSchema, Schema};
//! use sqlglot_rust::ast::DataType;
//! use sqlglot_rust::Dialect;
//!
//! let mut schema = MappingSchema::new(Dialect::Ansi);
//! schema.add_table(
//!     &["catalog", "db", "users"],
//!     vec![
//!         ("id".to_string(), DataType::Int),
//!         ("name".to_string(), DataType::Varchar(Some(255))),
//!         ("email".to_string(), DataType::Text),
//!     ],
//! ).unwrap();
//!
//! assert_eq!(
//!     schema.column_names(&["catalog", "db", "users"]).unwrap(),
//!     vec!["id", "name", "email"],
//! );
//! assert_eq!(
//!     schema.get_column_type(&["catalog", "db", "users"], "id").unwrap(),
//!     DataType::Int,
//! );
//! assert!(schema.has_column(&["catalog", "db", "users"], "id"));
//! ```

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use crate::ast::DataType;
use crate::dialects::Dialect;
use crate::errors::SqlglotError;

/// Failure modes of schema registration and lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// No table is registered under the given path. [`MappingSchema`] also
    /// uses this variant for paths it cannot resolve (not 1–3 parts).
    TableNotFound(String),
    /// The table exists but does not contain the requested column.
    ColumnNotFound {
        /// The table that was searched.
        table: String,
        /// The column that was requested.
        column: String,
    },
    /// A table is already registered under the given path.
    DuplicateTable(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TableNotFound(path) => write!(f, "Table not found: {path}"),
            Self::ColumnNotFound { table, column } => {
                write!(f, "Column '{column}' not found in table '{table}'")
            }
            Self::DuplicateTable(path) => write!(f, "Table already exists: {path}"),
        }
    }
}

// No wrapped cause is carried, so the default `source()` (None) is correct.
impl std::error::Error for SchemaError {}

66impl From<SchemaError> for SqlglotError {
67    fn from(e: SchemaError) -> Self {
68        SqlglotError::Internal(e.to_string())
69    }
70}
71
72/// Schema trait for schema-aware analysis and optimization.
73///
74/// Provides methods to query table and column metadata. Implementations
75/// can back this with in-memory mappings, database catalogs, etc.
76pub trait Schema {
77    /// Register a table with its column definitions.
78    ///
79    /// `table_path` is a slice of identifiers representing the fully qualified
80    /// table path: `[catalog, database, table]`, `[database, table]`, or `[table]`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`SchemaError::DuplicateTable`] if the table is already registered
85    /// (use `replace_table` to overwrite).
86    fn add_table(
87        &mut self,
88        table_path: &[&str],
89        columns: Vec<(String, DataType)>,
90    ) -> Result<(), SchemaError>;
91
92    /// Get the column names for a table, in definition order.
93    ///
94    /// # Errors
95    ///
96    /// Returns [`SchemaError::TableNotFound`] if the table is not registered.
97    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;
98
99    /// Get the data type of a specific column in a table.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`SchemaError::TableNotFound`] or [`SchemaError::ColumnNotFound`].
104    fn get_column_type(
105        &self,
106        table_path: &[&str],
107        column: &str,
108    ) -> Result<DataType, SchemaError>;
109
110    /// Check whether a column exists in the given table.
111    fn has_column(&self, table_path: &[&str], column: &str) -> bool;
112
113    /// Get the dialect used for identifier normalization.
114    fn dialect(&self) -> Dialect;
115}
116
117/// Column metadata stored inside a [`MappingSchema`].
118#[derive(Debug, Clone, PartialEq)]
119struct ColumnInfo {
120    /// Original insertion order for stable column listing.
121    columns: Vec<(String, DataType)>,
122    /// Fast lookup by normalized column name → index into `columns`.
123    index: HashMap<String, usize>,
124}
125
126impl ColumnInfo {
127    fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
128        let index = columns
129            .iter()
130            .enumerate()
131            .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
132            .collect();
133        Self { columns, index }
134    }
135
136    fn column_names(&self) -> Vec<String> {
137        self.columns.iter().map(|(n, _)| n.clone()).collect()
138    }
139
140    fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
141        let key = normalize_identifier(column, dialect);
142        self.index.get(&key).map(|&i| &self.columns[i].1)
143    }
144
145    fn has_column(&self, column: &str, dialect: Dialect) -> bool {
146        let key = normalize_identifier(column, dialect);
147        self.index.contains_key(&key)
148    }
149}
150
151/// A schema backed by in-memory hash maps, supporting 3-level nesting:
152/// `catalog → database → table → column → type`.
153///
154/// Analogous to Python sqlglot's `MappingSchema`.
155///
156/// Identifiers are normalized according to the configured dialect:
157/// - Case-insensitive dialects (most): identifiers are lowercased for lookup.
158/// - Case-sensitive dialects (e.g. BigQuery, Hive): identifiers are kept as-is.
159/// - Quoted identifiers are always stored verbatim (not normalized).
160#[derive(Debug, Clone)]
161pub struct MappingSchema {
162    dialect: Dialect,
163    /// Nested map: catalog → database → table → ColumnInfo
164    tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
165    /// UDF return type mappings: function_name (normalized) → return type
166    udf_types: HashMap<String, DataType>,
167}
168
169impl MappingSchema {
170    /// Create a new empty schema for the given dialect.
171    #[must_use]
172    pub fn new(dialect: Dialect) -> Self {
173        Self {
174            dialect,
175            tables: HashMap::new(),
176            udf_types: HashMap::new(),
177        }
178    }
179
180    /// Replace a table if it already exists, or add it if it doesn't.
181    pub fn replace_table(
182        &mut self,
183        table_path: &[&str],
184        columns: Vec<(String, DataType)>,
185    ) -> Result<(), SchemaError> {
186        let (catalog, database, table) = self.resolve_path(table_path)?;
187        let info = ColumnInfo::new(columns, self.dialect);
188        self.tables
189            .entry(catalog)
190            .or_default()
191            .entry(database)
192            .or_default()
193            .insert(table, info);
194        Ok(())
195    }
196
197    /// Remove a table from the schema. Returns `true` if the table existed.
198    pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
199        let (catalog, database, table) = self.resolve_path(table_path)?;
200        let removed = self
201            .tables
202            .get_mut(&catalog)
203            .and_then(|dbs| dbs.get_mut(&database))
204            .map(|tbls| tbls.remove(&table).is_some())
205            .unwrap_or(false);
206        Ok(removed)
207    }
208
209    /// Register a UDF (user-defined function) with its return type.
210    pub fn add_udf(&mut self, name: &str, return_type: DataType) {
211        let key = normalize_identifier(name, self.dialect);
212        self.udf_types.insert(key, return_type);
213    }
214
215    /// Get the return type of a registered UDF.
216    #[must_use]
217    pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
218        let key = normalize_identifier(name, self.dialect);
219        self.udf_types.get(&key)
220    }
221
222    /// List all registered tables as `(catalog, database, table)` triples.
223    #[must_use]
224    pub fn table_names(&self) -> Vec<(String, String, String)> {
225        let mut result = Vec::new();
226        for (catalog, dbs) in &self.tables {
227            for (database, tbls) in dbs {
228                for table in tbls.keys() {
229                    result.push((catalog.clone(), database.clone(), table.clone()));
230                }
231            }
232        }
233        result
234    }
235
236    /// Find a table across all catalogs/databases when only a short path is given.
237    /// Returns the first match found (useful for unqualified table references).
238    fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
239        let (catalog, database, table) = match self.resolve_path(table_path) {
240            Ok(parts) => parts,
241            Err(_) => return None,
242        };
243
244        // Exact match first
245        if let Some(info) = self
246            .tables
247            .get(&catalog)
248            .and_then(|dbs| dbs.get(&database))
249            .and_then(|tbls| tbls.get(&table))
250        {
251            return Some(info);
252        }
253
254        // For single-name lookups, search all catalogs/databases
255        if table_path.len() == 1 {
256            let norm_name = normalize_identifier(table_path[0], self.dialect);
257            for dbs in self.tables.values() {
258                for tbls in dbs.values() {
259                    if let Some(info) = tbls.get(&norm_name) {
260                        return Some(info);
261                    }
262                }
263            }
264        }
265
266        // For 2-part lookups (db.table), search all catalogs
267        if table_path.len() == 2 {
268            let norm_db = normalize_identifier(table_path[0], self.dialect);
269            let norm_tbl = normalize_identifier(table_path[1], self.dialect);
270            for dbs in self.tables.values() {
271                if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
272                    return Some(info);
273                }
274            }
275        }
276
277        None
278    }
279
280    /// Resolve a table path into normalized (catalog, database, table) parts,
281    /// filling in defaults for missing levels.
282    fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
283        match table_path.len() {
284            1 => Ok((
285                String::new(),
286                String::new(),
287                normalize_identifier(table_path[0], self.dialect),
288            )),
289            2 => Ok((
290                String::new(),
291                normalize_identifier(table_path[0], self.dialect),
292                normalize_identifier(table_path[1], self.dialect),
293            )),
294            3 => Ok((
295                normalize_identifier(table_path[0], self.dialect),
296                normalize_identifier(table_path[1], self.dialect),
297                normalize_identifier(table_path[2], self.dialect),
298            )),
299            _ => Err(SchemaError::TableNotFound(table_path.join("."))),
300        }
301    }
302
303    fn format_table_path(table_path: &[&str]) -> String {
304        table_path.join(".")
305    }
306}
307
308impl Schema for MappingSchema {
309    fn add_table(
310        &mut self,
311        table_path: &[&str],
312        columns: Vec<(String, DataType)>,
313    ) -> Result<(), SchemaError> {
314        let (catalog, database, table) = self.resolve_path(table_path)?;
315        let entry = self
316            .tables
317            .entry(catalog)
318            .or_default()
319            .entry(database)
320            .or_default();
321
322        if entry.contains_key(&table) {
323            return Err(SchemaError::DuplicateTable(
324                Self::format_table_path(table_path),
325            ));
326        }
327
328        let info = ColumnInfo::new(columns, self.dialect);
329        entry.insert(table, info);
330        Ok(())
331    }
332
333    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
334        self.find_table(table_path)
335            .map(|info| info.column_names())
336            .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
337    }
338
339    fn get_column_type(
340        &self,
341        table_path: &[&str],
342        column: &str,
343    ) -> Result<DataType, SchemaError> {
344        let table_str = Self::format_table_path(table_path);
345        let info = self
346            .find_table(table_path)
347            .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
348
349        info.get_type(column, self.dialect)
350            .cloned()
351            .ok_or(SchemaError::ColumnNotFound {
352                table: table_str,
353                column: column.to_string(),
354            })
355    }
356
357    fn has_column(&self, table_path: &[&str], column: &str) -> bool {
358        self.find_table(table_path)
359            .is_some_and(|info| info.has_column(column, self.dialect))
360    }
361
362    fn dialect(&self) -> Dialect {
363        self.dialect
364    }
365}
366
367// ═══════════════════════════════════════════════════════════════════════
368// Identifier normalization
369// ═══════════════════════════════════════════════════════════════════════
370
371/// Normalize an identifier according to the given dialect's conventions.
372///
373/// Most SQL dialects treat unquoted identifiers as case-insensitive
374/// (typically by converting to lowercase internally). A few dialects
375/// (BigQuery, Hive, Spark, Databricks) are case-sensitive for table/column
376/// names.
377#[must_use]
378pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
379    if is_case_sensitive_dialect(dialect) {
380        name.to_string()
381    } else {
382        name.to_lowercase()
383    }
384}
385
386/// Returns `true` if the dialect treats unquoted identifiers as case-sensitive.
387#[must_use]
388pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
389    matches!(
390        dialect,
391        Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
392    )
393}
394
395// ═══════════════════════════════════════════════════════════════════════
396// Helper: build schema from nested maps
397// ═══════════════════════════════════════════════════════════════════════
398
399/// Build a [`MappingSchema`] from a nested map structure.
400///
401/// The input maps from table name → column name → data type, mirroring
402/// the common pattern of constructing schemas from DDL or metadata.
403///
404/// # Example
405///
406/// ```rust
407/// use std::collections::HashMap;
408/// use sqlglot_rust::schema::{ensure_schema, Schema};
409/// use sqlglot_rust::ast::DataType;
410/// use sqlglot_rust::Dialect;
411///
412/// let mut tables = HashMap::new();
413/// let mut columns = HashMap::new();
414/// columns.insert("id".to_string(), DataType::Int);
415/// columns.insert("name".to_string(), DataType::Varchar(Some(255)));
416/// tables.insert("users".to_string(), columns);
417///
418/// let schema = ensure_schema(tables, Dialect::Postgres);
419/// assert!(schema.has_column(&["users"], "id"));
420/// ```
421pub fn ensure_schema(
422    tables: HashMap<String, HashMap<String, DataType>>,
423    dialect: Dialect,
424) -> MappingSchema {
425    let mut schema = MappingSchema::new(dialect);
426    for (table_name, columns) in tables {
427        let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
428        // Use replace_table to avoid DuplicateTable errors
429        let _ = schema.replace_table(&[&table_name], col_vec);
430    }
431    schema
432}
433
434/// Type alias for the 3-level nested schema map:
435/// `catalog → database → table → column → type`.
436pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
437
438/// Build a [`MappingSchema`] from a 3-level nested map:
439/// `catalog → database → table → column → type`.
440pub fn ensure_schema_nested(
441    catalog_map: CatalogMap,
442    dialect: Dialect,
443) -> MappingSchema {
444    let mut schema = MappingSchema::new(dialect);
445    for (catalog, databases) in catalog_map {
446        for (database, tables) in databases {
447            for (table, columns) in tables {
448                let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
449                let _ = schema.replace_table(
450                    &[&catalog, &database, &table],
451                    col_vec,
452                );
453            }
454        }
455    }
456    schema
457}
458
#[cfg(test)]
mod tests {
    //! Unit tests for `MappingSchema`, identifier normalization, the
    //! `ensure_schema*` builders, and `SchemaError` formatting/conversion.

    use super::*;

    // ── Basic operations ────────────────────────────────────────────────

    #[test]
    fn test_add_and_query_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("email".to_string(), DataType::Text),
                ],
            )
            .unwrap();

        // Column listing preserves registration order.
        assert_eq!(
            schema.column_names(&["users"]).unwrap(),
            vec!["id", "name", "email"]
        );
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
        assert_eq!(
            schema.get_column_type(&["users"], "name").unwrap(),
            DataType::Varchar(Some(255))
        );
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "email"));
        assert!(!schema.has_column(&["users"], "nonexistent"));
    }

    #[test]
    fn test_duplicate_table_error() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        // `add_table` refuses to overwrite; `replace_table` is the overwrite path.
        let err = schema
            .add_table(&["t"], vec![("b".to_string(), DataType::Text)])
            .unwrap_err();
        assert!(matches!(err, SchemaError::DuplicateTable(_)));
    }

    #[test]
    fn test_replace_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        schema
            .replace_table(&["t"], vec![("b".to_string(), DataType::Text)])
            .unwrap();

        // The old column set is discarded entirely, not merged.
        assert_eq!(schema.column_names(&["t"]).unwrap(), vec!["b"]);
        assert_eq!(
            schema.get_column_type(&["t"], "b").unwrap(),
            DataType::Text
        );
    }

    #[test]
    fn test_remove_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        // First removal reports `true`, a repeat reports `false`.
        assert!(schema.remove_table(&["t"]).unwrap());
        assert!(!schema.remove_table(&["t"]).unwrap());
        assert!(schema.column_names(&["t"]).is_err());
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        let err = schema.column_names(&["nonexistent"]).unwrap_err();
        assert!(matches!(err, SchemaError::TableNotFound(_)));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        let err = schema.get_column_type(&["t"], "z").unwrap_err();
        assert!(matches!(err, SchemaError::ColumnNotFound { .. }));
    }

    // ── Multi-level nesting ─────────────────────────────────────────────

    #[test]
    fn test_three_level_path() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["my_catalog", "my_db", "orders"],
                vec![
                    ("order_id".to_string(), DataType::BigInt),
                    ("total".to_string(), DataType::Decimal {
                        precision: Some(10),
                        scale: Some(2),
                    }),
                ],
            )
            .unwrap();

        assert_eq!(
            schema
                .column_names(&["my_catalog", "my_db", "orders"])
                .unwrap(),
            vec!["order_id", "total"]
        );
        assert!(schema.has_column(&["my_catalog", "my_db", "orders"], "order_id"));
    }

    #[test]
    fn test_two_level_path() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["public", "users"],
                vec![("id".to_string(), DataType::Int)],
            )
            .unwrap();

        assert_eq!(
            schema.column_names(&["public", "users"]).unwrap(),
            vec!["id"]
        );
    }

    #[test]
    fn test_short_path_searches_all() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["catalog", "db", "orders"],
                vec![("id".to_string(), DataType::Int)],
            )
            .unwrap();

        // Single-name lookup should find it
        assert!(schema.has_column(&["orders"], "id"));
        assert_eq!(schema.column_names(&["orders"]).unwrap(), vec!["id"]);

        // Two-part lookup should find it
        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    // ── Dialect-aware normalization ─────────────────────────────────────

    #[test]
    fn test_case_insensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::Postgres);
        schema
            .add_table(
                &["Users"],
                vec![("ID".to_string(), DataType::Int)],
            )
            .unwrap();

        // Lookups should be case-insensitive
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["USERS"], "ID"));
        assert!(schema.has_column(&["Users"], "Id"));
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
    }

    #[test]
    fn test_case_sensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema
            .add_table(
                &["Users"],
                vec![("ID".to_string(), DataType::Int)],
            )
            .unwrap();

        // BigQuery is case-sensitive
        assert!(schema.has_column(&["Users"], "ID"));
        assert!(!schema.has_column(&["users"], "ID"));
        assert!(!schema.has_column(&["Users"], "id"));
    }

    #[test]
    fn test_hive_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::Hive);
        schema
            .add_table(
                &["MyTable"],
                vec![("Col1".to_string(), DataType::Text)],
            )
            .unwrap();

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    // ── UDF return types ────────────────────────────────────────────────

    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        assert_eq!(
            schema.get_udf_type("my_custom_fn").unwrap(),
            &DataType::Int
        );
        // Case-insensitive for ANSI
        assert_eq!(
            schema.get_udf_type("MY_CUSTOM_FN").unwrap(),
            &DataType::Int
        );
        assert!(schema.get_udf_type("nonexistent").is_none());
    }

    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    // ── ensure_schema helpers ───────────────────────────────────────────

    #[test]
    fn test_ensure_schema() {
        let mut tables = HashMap::new();
        let mut cols = HashMap::new();
        cols.insert("id".to_string(), DataType::Int);
        cols.insert("name".to_string(), DataType::Text);
        tables.insert("users".to_string(), cols);

        let schema = ensure_schema(tables, Dialect::Postgres);
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "name"));
    }

    #[test]
    fn test_ensure_schema_nested() {
        let mut catalogs = HashMap::new();
        let mut databases = HashMap::new();
        let mut tables = HashMap::new();
        let mut cols = HashMap::new();
        cols.insert("order_id".to_string(), DataType::BigInt);
        tables.insert("orders".to_string(), cols);
        databases.insert("sales".to_string(), tables);
        catalogs.insert("warehouse".to_string(), databases);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        // Short-path lookup
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    // ── table_names listing ─────────────────────────────────────────────

    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["cat", "db", "t1"],
                vec![("a".to_string(), DataType::Int)],
            )
            .unwrap();
        schema
            .add_table(
                &["cat", "db", "t2"],
                vec![("b".to_string(), DataType::Int)],
            )
            .unwrap();

        // Sorted locally so the assertions don't depend on listing order.
        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        assert!(names.iter().any(|(c, d, t)| c == "cat" && d == "db" && t == "t1"));
        assert!(names.iter().any(|(c, d, t)| c == "cat" && d == "db" && t == "t2"));
    }

    // ── Invalid path ────────────────────────────────────────────────────

    #[test]
    fn test_invalid_path_too_many_parts() {
        // Paths longer than catalog.database.table can't be resolved and are
        // reported as `TableNotFound`.
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let err = schema
            .add_table(
                &["a", "b", "c", "d"],
                vec![("x".to_string(), DataType::Int)],
            )
            .unwrap_err();
        assert!(matches!(err, SchemaError::TableNotFound(_)));
    }

    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    // ── Schema error display ────────────────────────────────────────────

    #[test]
    fn test_schema_error_display() {
        let e = SchemaError::TableNotFound("users".to_string());
        assert_eq!(e.to_string(), "Table not found: users");

        let e = SchemaError::ColumnNotFound {
            table: "users".to_string(),
            column: "age".to_string(),
        };
        assert_eq!(e.to_string(), "Column 'age' not found in table 'users'");

        let e = SchemaError::DuplicateTable("users".to_string());
        assert_eq!(e.to_string(), "Table already exists: users");
    }

    // ── SchemaError → SqlglotError conversion ───────────────────────────

    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let e: SqlglotError = SchemaError::TableNotFound("t".to_string()).into();
        assert!(matches!(e, SqlglotError::Internal(_)));
    }

    // ── Multiple dialects ───────────────────────────────────────────────

    #[test]
    fn test_multiple_dialects_normalization() {
        // Postgres: case-insensitive
        let mut pg = MappingSchema::new(Dialect::Postgres);
        pg.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(pg.has_column(&["t"], "c"));

        // MySQL: case-insensitive
        let mut my = MappingSchema::new(Dialect::Mysql);
        my.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(my.has_column(&["t"], "c"));

        // Spark: case-sensitive
        let mut sp = MappingSchema::new(Dialect::Spark);
        sp.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(!sp.has_column(&["t"], "c"));
        assert!(sp.has_column(&["T"], "C"));
    }

    // ── Complex data types ──────────────────────────────────────────────

    #[test]
    fn test_complex_data_types() {
        // Nested types are stored and returned verbatim (compared via PartialEq).
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["complex_table"],
                vec![
                    ("tags".to_string(), DataType::Array(Some(Box::new(DataType::Text)))),
                    ("metadata".to_string(), DataType::Json),
                    ("coords".to_string(), DataType::Struct(vec![
                        ("lat".to_string(), DataType::Double),
                        ("lng".to_string(), DataType::Double),
                    ])),
                    ("lookup".to_string(), DataType::Map {
                        key: Box::new(DataType::Text),
                        value: Box::new(DataType::Int),
                    }),
                ],
            )
            .unwrap();

        assert_eq!(
            schema.get_column_type(&["complex_table"], "tags").unwrap(),
            DataType::Array(Some(Box::new(DataType::Text)))
        );
        assert_eq!(
            schema.get_column_type(&["complex_table"], "metadata").unwrap(),
            DataType::Json
        );
    }

    // ── dialect() accessor ──────────────────────────────────────────────

    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}