Skip to main content

sqlglot_rust/schema/
mod.rs

1//! Schema management system for schema-aware analysis and optimization.
2//!
3//! Provides a [`Schema`] trait and [`MappingSchema`] implementation analogous
4//! to Python sqlglot's `MappingSchema`. This is the foundation for type
5//! annotation, column qualification, projection pushdown, and lineage analysis.
6//!
7//! # Example
8//!
9//! ```rust
10//! use sqlglot_rust::schema::{MappingSchema, Schema};
11//! use sqlglot_rust::ast::DataType;
12//! use sqlglot_rust::Dialect;
13//!
14//! let mut schema = MappingSchema::new(Dialect::Ansi);
15//! schema.add_table(
16//!     &["catalog", "db", "users"],
17//!     vec![
18//!         ("id".to_string(), DataType::Int),
19//!         ("name".to_string(), DataType::Varchar(Some(255))),
20//!         ("email".to_string(), DataType::Text),
21//!     ],
22//! ).unwrap();
23//!
24//! assert_eq!(
25//!     schema.column_names(&["catalog", "db", "users"]).unwrap(),
26//!     vec!["id", "name", "email"],
27//! );
28//! assert_eq!(
29//!     schema.get_column_type(&["catalog", "db", "users"], "id").unwrap(),
30//!     DataType::Int,
31//! );
32//! assert!(schema.has_column(&["catalog", "db", "users"], "id"));
33//! ```
34
35use std::collections::HashMap;
36
37use crate::ast::DataType;
38use crate::dialects::Dialect;
39use crate::errors::SqlglotError;
40
/// Errors specific to schema operations.
///
/// Each variant carries the user-supplied names involved so the `Display`
/// output can point at the offending table or column. The
/// `From<SchemaError> for SqlglotError` impl in this module wraps that text
/// in `SqlglotError::Internal`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// The referenced table was not found in the schema.
    ///
    /// `MappingSchema` also returns this variant when a table path does not
    /// have one, two, or three parts; the payload is the path joined with `.`.
    TableNotFound(String),
    /// The referenced column was not found in the table.
    ColumnNotFound { table: String, column: String },
    /// Duplicate table registration.
    DuplicateTable(String),
}
51
52impl std::fmt::Display for SchemaError {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            SchemaError::TableNotFound(t) => write!(f, "Table not found: {t}"),
56            SchemaError::ColumnNotFound { table, column } => {
57                write!(f, "Column '{column}' not found in table '{table}'")
58            }
59            SchemaError::DuplicateTable(t) => write!(f, "Table already exists: {t}"),
60        }
61    }
62}
63
// Marker impl: all `std::error::Error` methods keep their defaults, so a
// `SchemaError` never reports an underlying `source()`.
impl std::error::Error for SchemaError {}
65
66impl From<SchemaError> for SqlglotError {
67    fn from(e: SchemaError) -> Self {
68        SqlglotError::Internal(e.to_string())
69    }
70}
71
/// Schema trait for schema-aware analysis and optimization.
///
/// Provides methods to query table and column metadata. Implementations
/// can back this with in-memory mappings, database catalogs, etc.
///
/// All lookups take the table as a path slice rather than a single string so
/// that catalog and database qualifiers stay separate from the table name.
pub trait Schema {
    /// Register a table with its column definitions.
    ///
    /// `table_path` is a slice of identifiers representing the fully qualified
    /// table path: `[catalog, database, table]`, `[database, table]`, or `[table]`.
    /// `columns` lists `(name, type)` pairs in definition order.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::DuplicateTable`] if the table is already registered.
    /// To overwrite an existing table on a [`MappingSchema`], use
    /// `MappingSchema::replace_table` (it is not part of this trait).
    fn add_table(
        &mut self,
        table_path: &[&str],
        columns: Vec<(String, DataType)>,
    ) -> Result<(), SchemaError>;

    /// Get the column names for a table, in definition order.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::TableNotFound`] if the table is not registered.
    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;

    /// Get the data type of a specific column in a table.
    ///
    /// The type is returned by value (owned), not as a reference into the schema.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::TableNotFound`] or [`SchemaError::ColumnNotFound`].
    fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError>;

    /// Check whether a column exists in the given table.
    ///
    /// This method cannot report errors; the [`MappingSchema`] implementation
    /// answers `false` when the table itself cannot be found.
    fn has_column(&self, table_path: &[&str], column: &str) -> bool;

    /// Get the dialect used for identifier normalization.
    fn dialect(&self) -> Dialect;
}
112
/// Column metadata stored inside a [`MappingSchema`].
///
/// Holds one table's columns twice: as an ordered list (for listing) and as
/// a name → position index (for lookup).
#[derive(Debug, Clone, PartialEq)]
struct ColumnInfo {
    /// Columns as `(name, type)` pairs in the order they were supplied; names
    /// keep their original spelling.
    columns: Vec<(String, DataType)>,
    /// Fast lookup by normalized column name → index into `columns`.
    // When two names normalize to the same key, the position of the later
    // column is the one kept (see `ColumnInfo::new`).
    index: HashMap<String, usize>,
}
121
122impl ColumnInfo {
123    fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
124        let index = columns
125            .iter()
126            .enumerate()
127            .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
128            .collect();
129        Self { columns, index }
130    }
131
132    fn column_names(&self) -> Vec<String> {
133        self.columns.iter().map(|(n, _)| n.clone()).collect()
134    }
135
136    fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
137        let key = normalize_identifier(column, dialect);
138        self.index.get(&key).map(|&i| &self.columns[i].1)
139    }
140
141    fn has_column(&self, column: &str, dialect: Dialect) -> bool {
142        let key = normalize_identifier(column, dialect);
143        self.index.contains_key(&key)
144    }
145}
146
147/// A schema backed by in-memory hash maps, supporting 3-level nesting:
148/// `catalog → database → table → column → type`.
149///
150/// Analogous to Python sqlglot's `MappingSchema`.
151///
152/// Identifiers are normalized according to the configured dialect:
153/// - Case-insensitive dialects (most): identifiers are lowercased for lookup.
154/// - Case-sensitive dialects (e.g. BigQuery, Hive): identifiers are kept as-is.
155/// - Quoted identifiers are always stored verbatim (not normalized).
156#[derive(Debug, Clone)]
157pub struct MappingSchema {
158    dialect: Dialect,
159    /// Nested map: catalog → database → table → ColumnInfo
160    tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
161    /// UDF return type mappings: function_name (normalized) → return type
162    udf_types: HashMap<String, DataType>,
163}
164
165impl MappingSchema {
166    /// Create a new empty schema for the given dialect.
167    #[must_use]
168    pub fn new(dialect: Dialect) -> Self {
169        Self {
170            dialect,
171            tables: HashMap::new(),
172            udf_types: HashMap::new(),
173        }
174    }
175
176    /// Replace a table if it already exists, or add it if it doesn't.
177    pub fn replace_table(
178        &mut self,
179        table_path: &[&str],
180        columns: Vec<(String, DataType)>,
181    ) -> Result<(), SchemaError> {
182        let (catalog, database, table) = self.resolve_path(table_path)?;
183        let info = ColumnInfo::new(columns, self.dialect);
184        self.tables
185            .entry(catalog)
186            .or_default()
187            .entry(database)
188            .or_default()
189            .insert(table, info);
190        Ok(())
191    }
192
193    /// Remove a table from the schema. Returns `true` if the table existed.
194    pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
195        let (catalog, database, table) = self.resolve_path(table_path)?;
196        let removed = self
197            .tables
198            .get_mut(&catalog)
199            .and_then(|dbs| dbs.get_mut(&database))
200            .map(|tbls| tbls.remove(&table).is_some())
201            .unwrap_or(false);
202        Ok(removed)
203    }
204
205    /// Register a UDF (user-defined function) with its return type.
206    pub fn add_udf(&mut self, name: &str, return_type: DataType) {
207        let key = normalize_identifier(name, self.dialect);
208        self.udf_types.insert(key, return_type);
209    }
210
211    /// Get the return type of a registered UDF.
212    #[must_use]
213    pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
214        let key = normalize_identifier(name, self.dialect);
215        self.udf_types.get(&key)
216    }
217
218    /// List all registered tables as `(catalog, database, table)` triples.
219    #[must_use]
220    pub fn table_names(&self) -> Vec<(String, String, String)> {
221        let mut result = Vec::new();
222        for (catalog, dbs) in &self.tables {
223            for (database, tbls) in dbs {
224                for table in tbls.keys() {
225                    result.push((catalog.clone(), database.clone(), table.clone()));
226                }
227            }
228        }
229        result
230    }
231
232    /// Find a table across all catalogs/databases when only a short path is given.
233    /// Returns the first match found (useful for unqualified table references).
234    fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
235        let (catalog, database, table) = match self.resolve_path(table_path) {
236            Ok(parts) => parts,
237            Err(_) => return None,
238        };
239
240        // Exact match first
241        if let Some(info) = self
242            .tables
243            .get(&catalog)
244            .and_then(|dbs| dbs.get(&database))
245            .and_then(|tbls| tbls.get(&table))
246        {
247            return Some(info);
248        }
249
250        // For single-name lookups, search all catalogs/databases
251        if table_path.len() == 1 {
252            let norm_name = normalize_identifier(table_path[0], self.dialect);
253            for dbs in self.tables.values() {
254                for tbls in dbs.values() {
255                    if let Some(info) = tbls.get(&norm_name) {
256                        return Some(info);
257                    }
258                }
259            }
260        }
261
262        // For 2-part lookups (db.table), search all catalogs
263        if table_path.len() == 2 {
264            let norm_db = normalize_identifier(table_path[0], self.dialect);
265            let norm_tbl = normalize_identifier(table_path[1], self.dialect);
266            for dbs in self.tables.values() {
267                if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
268                    return Some(info);
269                }
270            }
271        }
272
273        None
274    }
275
276    /// Resolve a table path into normalized (catalog, database, table) parts,
277    /// filling in defaults for missing levels.
278    fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
279        match table_path.len() {
280            1 => Ok((
281                String::new(),
282                String::new(),
283                normalize_identifier(table_path[0], self.dialect),
284            )),
285            2 => Ok((
286                String::new(),
287                normalize_identifier(table_path[0], self.dialect),
288                normalize_identifier(table_path[1], self.dialect),
289            )),
290            3 => Ok((
291                normalize_identifier(table_path[0], self.dialect),
292                normalize_identifier(table_path[1], self.dialect),
293                normalize_identifier(table_path[2], self.dialect),
294            )),
295            _ => Err(SchemaError::TableNotFound(table_path.join("."))),
296        }
297    }
298
299    fn format_table_path(table_path: &[&str]) -> String {
300        table_path.join(".")
301    }
302}
303
304impl Schema for MappingSchema {
305    fn add_table(
306        &mut self,
307        table_path: &[&str],
308        columns: Vec<(String, DataType)>,
309    ) -> Result<(), SchemaError> {
310        let (catalog, database, table) = self.resolve_path(table_path)?;
311        let entry = self
312            .tables
313            .entry(catalog)
314            .or_default()
315            .entry(database)
316            .or_default();
317
318        if entry.contains_key(&table) {
319            return Err(SchemaError::DuplicateTable(Self::format_table_path(
320                table_path,
321            )));
322        }
323
324        let info = ColumnInfo::new(columns, self.dialect);
325        entry.insert(table, info);
326        Ok(())
327    }
328
329    fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
330        self.find_table(table_path)
331            .map(|info| info.column_names())
332            .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
333    }
334
335    fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError> {
336        let table_str = Self::format_table_path(table_path);
337        let info = self
338            .find_table(table_path)
339            .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
340
341        info.get_type(column, self.dialect)
342            .cloned()
343            .ok_or(SchemaError::ColumnNotFound {
344                table: table_str,
345                column: column.to_string(),
346            })
347    }
348
349    fn has_column(&self, table_path: &[&str], column: &str) -> bool {
350        self.find_table(table_path)
351            .is_some_and(|info| info.has_column(column, self.dialect))
352    }
353
354    fn dialect(&self) -> Dialect {
355        self.dialect
356    }
357}
358
359// ═══════════════════════════════════════════════════════════════════════
360// Identifier normalization
361// ═══════════════════════════════════════════════════════════════════════
362
363/// Normalize an identifier according to the given dialect's conventions.
364///
365/// Most SQL dialects treat unquoted identifiers as case-insensitive
366/// (typically by converting to lowercase internally). A few dialects
367/// (BigQuery, Hive, Spark, Databricks) are case-sensitive for table/column
368/// names.
369#[must_use]
370pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
371    if is_case_sensitive_dialect(dialect) {
372        name.to_string()
373    } else {
374        name.to_lowercase()
375    }
376}
377
378/// Returns `true` if the dialect treats unquoted identifiers as case-sensitive.
379#[must_use]
380pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
381    matches!(
382        dialect,
383        Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
384    )
385}
386
387// ═══════════════════════════════════════════════════════════════════════
388// Helper: build schema from nested maps
389// ═══════════════════════════════════════════════════════════════════════
390
391/// Build a [`MappingSchema`] from a nested map structure.
392///
393/// The input maps from table name → column name → data type, mirroring
394/// the common pattern of constructing schemas from DDL or metadata.
395///
396/// # Example
397///
398/// ```rust
399/// use std::collections::HashMap;
400/// use sqlglot_rust::schema::{ensure_schema, Schema};
401/// use sqlglot_rust::ast::DataType;
402/// use sqlglot_rust::Dialect;
403///
404/// let mut tables = HashMap::new();
405/// let mut columns = HashMap::new();
406/// columns.insert("id".to_string(), DataType::Int);
407/// columns.insert("name".to_string(), DataType::Varchar(Some(255)));
408/// tables.insert("users".to_string(), columns);
409///
410/// let schema = ensure_schema(tables, Dialect::Postgres);
411/// assert!(schema.has_column(&["users"], "id"));
412/// ```
413pub fn ensure_schema(
414    tables: HashMap<String, HashMap<String, DataType>>,
415    dialect: Dialect,
416) -> MappingSchema {
417    let mut schema = MappingSchema::new(dialect);
418    for (table_name, columns) in tables {
419        let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
420        // Use replace_table to avoid DuplicateTable errors
421        let _ = schema.replace_table(&[&table_name], col_vec);
422    }
423    schema
424}
425
/// Type alias for the 3-level nested schema map:
/// `catalog → database → table → column → type`.
///
/// This is the input shape accepted by [`ensure_schema_nested`]. All four
/// map levels are keyed by `String`; the innermost value is the column's
/// [`DataType`].
pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
429
430/// Build a [`MappingSchema`] from a 3-level nested map:
431/// `catalog → database → table → column → type`.
432pub fn ensure_schema_nested(catalog_map: CatalogMap, dialect: Dialect) -> MappingSchema {
433    let mut schema = MappingSchema::new(dialect);
434    for (catalog, databases) in catalog_map {
435        for (database, tables) in databases {
436            for (table, columns) in tables {
437                let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
438                let _ = schema.replace_table(&[&catalog, &database, &table], col_vec);
439            }
440        }
441    }
442    schema
443}
444
// Unit tests for `MappingSchema`, the `Schema` trait implementation, the
// identifier-normalization helpers, and the `ensure_schema*` builders.
#[cfg(test)]
mod tests {
    use super::*;

    // ── Basic operations ────────────────────────────────────────────────

    // Registers a one-part table and checks names, types, and presence.
    #[test]
    fn test_add_and_query_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("email".to_string(), DataType::Text),
                ],
            )
            .unwrap();

        assert_eq!(
            schema.column_names(&["users"]).unwrap(),
            vec!["id", "name", "email"]
        );
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
        assert_eq!(
            schema.get_column_type(&["users"], "name").unwrap(),
            DataType::Varchar(Some(255))
        );
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "email"));
        assert!(!schema.has_column(&["users"], "nonexistent"));
    }

    // A second `add_table` on the same path must be rejected.
    #[test]
    fn test_duplicate_table_error() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        let err = schema
            .add_table(&["t"], vec![("b".to_string(), DataType::Text)])
            .unwrap_err();
        assert!(matches!(err, SchemaError::DuplicateTable(_)));
    }

    // `replace_table` overwrites the previous column set.
    #[test]
    fn test_replace_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        schema
            .replace_table(&["t"], vec![("b".to_string(), DataType::Text)])
            .unwrap();

        assert_eq!(schema.column_names(&["t"]).unwrap(), vec!["b"]);
        assert_eq!(schema.get_column_type(&["t"], "b").unwrap(), DataType::Text);
    }

    // Removal reports whether the table existed; a second removal is `false`.
    #[test]
    fn test_remove_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        assert!(schema.remove_table(&["t"]).unwrap());
        assert!(!schema.remove_table(&["t"]).unwrap());
        assert!(schema.column_names(&["t"]).is_err());
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        let err = schema.column_names(&["nonexistent"]).unwrap_err();
        assert!(matches!(err, SchemaError::TableNotFound(_)));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![("a".to_string(), DataType::Int)])
            .unwrap();

        let err = schema.get_column_type(&["t"], "z").unwrap_err();
        assert!(matches!(err, SchemaError::ColumnNotFound { .. }));
    }

    // ── Multi-level nesting ─────────────────────────────────────────────

    // Full `[catalog, database, table]` path round-trips.
    #[test]
    fn test_three_level_path() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["my_catalog", "my_db", "orders"],
                vec![
                    ("order_id".to_string(), DataType::BigInt),
                    (
                        "total".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                ],
            )
            .unwrap();

        assert_eq!(
            schema
                .column_names(&["my_catalog", "my_db", "orders"])
                .unwrap(),
            vec!["order_id", "total"]
        );
        assert!(schema.has_column(&["my_catalog", "my_db", "orders"], "order_id"));
    }

    // `[database, table]` path round-trips.
    #[test]
    fn test_two_level_path() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["public", "users"],
                vec![("id".to_string(), DataType::Int)],
            )
            .unwrap();

        assert_eq!(
            schema.column_names(&["public", "users"]).unwrap(),
            vec!["id"]
        );
    }

    // A table registered with a full path is reachable through shorter paths.
    #[test]
    fn test_short_path_searches_all() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["catalog", "db", "orders"],
                vec![("id".to_string(), DataType::Int)],
            )
            .unwrap();

        // Single-name lookup should find it
        assert!(schema.has_column(&["orders"], "id"));
        assert_eq!(schema.column_names(&["orders"]).unwrap(), vec!["id"]);

        // Two-part lookup should find it
        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    // ── Dialect-aware normalization ─────────────────────────────────────

    #[test]
    fn test_case_insensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::Postgres);
        schema
            .add_table(&["Users"], vec![("ID".to_string(), DataType::Int)])
            .unwrap();

        // Lookups should be case-insensitive
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["USERS"], "ID"));
        assert!(schema.has_column(&["Users"], "Id"));
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
    }

    #[test]
    fn test_case_sensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema
            .add_table(&["Users"], vec![("ID".to_string(), DataType::Int)])
            .unwrap();

        // BigQuery is case-sensitive
        assert!(schema.has_column(&["Users"], "ID"));
        assert!(!schema.has_column(&["users"], "ID"));
        assert!(!schema.has_column(&["Users"], "id"));
    }

    #[test]
    fn test_hive_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::Hive);
        schema
            .add_table(&["MyTable"], vec![("Col1".to_string(), DataType::Text)])
            .unwrap();

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    // ── UDF return types ────────────────────────────────────────────────

    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        assert_eq!(schema.get_udf_type("my_custom_fn").unwrap(), &DataType::Int);
        // Case-insensitive for ANSI
        assert_eq!(schema.get_udf_type("MY_CUSTOM_FN").unwrap(), &DataType::Int);
        assert!(schema.get_udf_type("nonexistent").is_none());
    }

    // UDF names follow the same dialect normalization as table/column names.
    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    // ── ensure_schema helpers ───────────────────────────────────────────

    #[test]
    fn test_ensure_schema() {
        let mut tables = HashMap::new();
        let mut cols = HashMap::new();
        cols.insert("id".to_string(), DataType::Int);
        cols.insert("name".to_string(), DataType::Text);
        tables.insert("users".to_string(), cols);

        let schema = ensure_schema(tables, Dialect::Postgres);
        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "name"));
    }

    #[test]
    fn test_ensure_schema_nested() {
        let mut catalogs = HashMap::new();
        let mut databases = HashMap::new();
        let mut tables = HashMap::new();
        let mut cols = HashMap::new();
        cols.insert("order_id".to_string(), DataType::BigInt);
        tables.insert("orders".to_string(), cols);
        databases.insert("sales".to_string(), tables);
        catalogs.insert("warehouse".to_string(), databases);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        // Short-path lookup
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    // ── table_names listing ─────────────────────────────────────────────

    // `table_names` order is not fixed, so membership is checked with `any`.
    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["cat", "db", "t1"], vec![("a".to_string(), DataType::Int)])
            .unwrap();
        schema
            .add_table(&["cat", "db", "t2"], vec![("b".to_string(), DataType::Int)])
            .unwrap();

        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        assert!(
            names
                .iter()
                .any(|(c, d, t)| c == "cat" && d == "db" && t == "t1")
        );
        assert!(
            names
                .iter()
                .any(|(c, d, t)| c == "cat" && d == "db" && t == "t2")
        );
    }

    // ── Invalid path ────────────────────────────────────────────────────

    // Paths with more than three parts are reported as `TableNotFound`.
    #[test]
    fn test_invalid_path_too_many_parts() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let err = schema
            .add_table(
                &["a", "b", "c", "d"],
                vec![("x".to_string(), DataType::Int)],
            )
            .unwrap_err();
        assert!(matches!(err, SchemaError::TableNotFound(_)));
    }

    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    // ── Schema error display ────────────────────────────────────────────

    // Pins the exact `Display` text of every `SchemaError` variant.
    #[test]
    fn test_schema_error_display() {
        let e = SchemaError::TableNotFound("users".to_string());
        assert_eq!(e.to_string(), "Table not found: users");

        let e = SchemaError::ColumnNotFound {
            table: "users".to_string(),
            column: "age".to_string(),
        };
        assert_eq!(e.to_string(), "Column 'age' not found in table 'users'");

        let e = SchemaError::DuplicateTable("users".to_string());
        assert_eq!(e.to_string(), "Table already exists: users");
    }

    // ── SchemaError → SqlglotError conversion ───────────────────────────

    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let e: SqlglotError = SchemaError::TableNotFound("t".to_string()).into();
        assert!(matches!(e, SqlglotError::Internal(_)));
    }

    // ── Multiple dialects ───────────────────────────────────────────────

    #[test]
    fn test_multiple_dialects_normalization() {
        // Postgres: case-insensitive
        let mut pg = MappingSchema::new(Dialect::Postgres);
        pg.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(pg.has_column(&["t"], "c"));

        // MySQL: case-insensitive
        let mut my = MappingSchema::new(Dialect::Mysql);
        my.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(my.has_column(&["t"], "c"));

        // Spark: case-sensitive
        let mut sp = MappingSchema::new(Dialect::Spark);
        sp.add_table(&["T"], vec![("C".to_string(), DataType::Int)])
            .unwrap();
        assert!(!sp.has_column(&["t"], "c"));
        assert!(sp.has_column(&["T"], "C"));
    }

    // ── Complex data types ──────────────────────────────────────────────

    // Nested/parameterized types are stored and returned unchanged.
    #[test]
    fn test_complex_data_types() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["complex_table"],
                vec![
                    (
                        "tags".to_string(),
                        DataType::Array(Some(Box::new(DataType::Text))),
                    ),
                    ("metadata".to_string(), DataType::Json),
                    (
                        "coords".to_string(),
                        DataType::Struct(vec![
                            ("lat".to_string(), DataType::Double),
                            ("lng".to_string(), DataType::Double),
                        ]),
                    ),
                    (
                        "lookup".to_string(),
                        DataType::Map {
                            key: Box::new(DataType::Text),
                            value: Box::new(DataType::Int),
                        },
                    ),
                ],
            )
            .unwrap();

        assert_eq!(
            schema.get_column_type(&["complex_table"], "tags").unwrap(),
            DataType::Array(Some(Box::new(DataType::Text)))
        );
        assert_eq!(
            schema
                .get_column_type(&["complex_table"], "metadata")
                .unwrap(),
            DataType::Json
        );
    }

    // ── dialect() accessor ──────────────────────────────────────────────

    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}