Skip to main content

smelt_types/
parse.rs

//! Type string parsing
//!
//! Parses SQL type strings (e.g., "VARCHAR(255)", "DECIMAL(10,2)", "STRUCT(a INTEGER)") into
//! DataType, handling common aliases across SQL dialects as well as composite ARRAY, STRUCT
//! and MAP types.
5
6use crate::DataType;
7use thiserror::Error;
8
9/// Error parsing a type string
10#[derive(Debug, Error, PartialEq, Eq)]
11pub enum TypeParseError {
12    #[error("empty type string")]
13    EmptyString,
14    #[error("unknown type: {0}")]
15    UnknownType(String),
16    #[error("invalid precision/scale for DECIMAL: {0}")]
17    InvalidDecimal(String),
18    #[error("invalid length for {type_name}: {value}")]
19    InvalidLength { type_name: String, value: String },
20    #[error("missing closing parenthesis")]
21    MissingCloseParen,
22    #[error("invalid STRUCT: {0}")]
23    InvalidStruct(String),
24    #[error("invalid MAP: {0}")]
25    InvalidMap(String),
26}
27
28/// Parse a SQL type string into a DataType
29///
30/// Supports common SQL type names and aliases:
31/// - Numeric: INT, INTEGER, BIGINT, SMALLINT, FLOAT, DOUBLE, REAL, DECIMAL, NUMERIC
32/// - String: VARCHAR, CHAR, TEXT, STRING
33/// - Boolean: BOOLEAN, BOOL
34/// - Date/Time: DATE, TIME, TIMESTAMP, TIMESTAMPTZ, INTERVAL
35/// - Binary: BLOB, BYTEA, BINARY
36///
37/// # Examples
38/// ```
39/// use smelt_types::parse_type;
40///
41/// let ty = parse_type("INTEGER").unwrap();
42/// let ty = parse_type("VARCHAR(255)").unwrap();
43/// let ty = parse_type("DECIMAL(10,2)").unwrap();
44/// let ty = parse_type("TIMESTAMP WITH TIME ZONE").unwrap();
45/// ```
46pub fn parse_type(type_str: &str) -> Result<DataType, TypeParseError> {
47    let type_str = type_str.trim();
48    if type_str.is_empty() {
49        return Err(TypeParseError::EmptyString);
50    }
51
52    let upper = type_str.to_uppercase();
53    parse_type_inner(&upper)
54}
55
56/// Recursive inner parser that handles complex types
57fn parse_type_inner(upper: &str) -> Result<DataType, TypeParseError> {
58    let upper = upper.trim();
59    if upper.is_empty() {
60        return Err(TypeParseError::EmptyString);
61    }
62
63    // Check for [] suffix (array bracket notation) — peel from right
64    if let Some(inner) = upper.strip_suffix("[]") {
65        let inner = inner.trim();
66        let inner_type = parse_type_inner(inner)?;
67        return Ok(DataType::Array(Box::new(inner_type)));
68    }
69
70    // Check for " ARRAY" suffix (SQL standard notation)
71    // Must not match "ARRAY" by itself or "ARRAY(...)"
72    if let Some(inner) = upper.strip_suffix(" ARRAY") {
73        let inner = inner.trim();
74        if !inner.is_empty() {
75            let inner_type = parse_type_inner(inner)?;
76            return Ok(DataType::Array(Box::new(inner_type)));
77        }
78    }
79
80    // Handle STRUCT(...) — must use matching-paren logic
81    if upper.starts_with("STRUCT(") || upper.starts_with("STRUCT (") {
82        let open = upper.find('(').unwrap();
83        let close = find_matching_paren(upper, open).ok_or(TypeParseError::MissingCloseParen)?;
84        // There should be nothing after the closing paren ([] was already handled above)
85        let trailing = upper[close + 1..].trim();
86        if !trailing.is_empty() {
87            return Err(TypeParseError::InvalidStruct(format!(
88                "unexpected trailing characters: {trailing}"
89            )));
90        }
91        let fields_str = &upper[open + 1..close];
92        return parse_struct_fields(fields_str);
93    }
94
95    // Handle MAP(...)
96    if upper.starts_with("MAP(") || upper.starts_with("MAP (") {
97        let open = upper.find('(').unwrap();
98        let close = find_matching_paren(upper, open).ok_or(TypeParseError::MissingCloseParen)?;
99        let trailing = upper[close + 1..].trim();
100        if !trailing.is_empty() {
101            return Err(TypeParseError::InvalidMap(format!(
102                "unexpected trailing characters: {trailing}"
103            )));
104        }
105        let params_str = &upper[open + 1..close];
106        return parse_map_params(params_str);
107    }
108
109    // Handle ARRAY(...) prefix notation (Spark style)
110    if upper.starts_with("ARRAY(") || upper.starts_with("ARRAY (") {
111        let open = upper.find('(').unwrap();
112        let close = find_matching_paren(upper, open).ok_or(TypeParseError::MissingCloseParen)?;
113        let trailing = upper[close + 1..].trim();
114        if !trailing.is_empty() {
115            return Err(TypeParseError::UnknownType(upper.to_string()));
116        }
117        let inner_str = &upper[open + 1..close];
118        let inner_type = parse_type_inner(inner_str)?;
119        return Ok(DataType::Array(Box::new(inner_type)));
120    }
121
122    // Handle parameterized scalar types (VARCHAR(...), DECIMAL(...), etc.)
123    if let Some(paren_pos) = upper.find('(') {
124        return parse_parameterized_type(upper, paren_pos);
125    }
126
127    // Handle multi-word types
128    if upper.starts_with("TIMESTAMP") {
129        return parse_timestamp_type(upper);
130    }
131
132    // Simple types without parameters
133    parse_simple_type(upper)
134}
135
136/// Parse simple (non-parameterized) types
137fn parse_simple_type(upper: &str) -> Result<DataType, TypeParseError> {
138    match upper {
139        // Boolean
140        "BOOLEAN" | "BOOL" => Ok(DataType::Boolean),
141
142        // Integer types
143        "TINYINT" | "INT1" => Ok(DataType::SmallInt),
144        "SMALLINT" | "INT2" => Ok(DataType::SmallInt),
145        "INT" | "INTEGER" | "INT4" => Ok(DataType::Integer),
146        "BIGINT" | "INT8" | "LONG" => Ok(DataType::BigInt),
147        "HUGEINT" | "INT16" => Ok(DataType::BigInt),
148
149        // Floating point
150        "REAL" | "FLOAT4" | "FLOAT" => Ok(DataType::Float),
151        "DOUBLE" | "FLOAT8" | "DOUBLE PRECISION" => Ok(DataType::Double),
152
153        // String types (without length)
154        "VARCHAR" | "STRING" | "TEXT" => Ok(DataType::Varchar { max_length: None }),
155        "CHAR" | "CHARACTER" => Ok(DataType::Char { length: 1 }),
156
157        // Date/Time
158        "DATE" => Ok(DataType::Date),
159        "TIME" => Ok(DataType::Time),
160        "TIMESTAMP" => Ok(DataType::Timestamp {
161            with_timezone: false,
162        }),
163        "TIMESTAMPTZ" => Ok(DataType::Timestamp {
164            with_timezone: true,
165        }),
166        "INTERVAL" => Ok(DataType::Interval),
167
168        // Binary
169        "BLOB" | "BYTEA" | "BINARY" | "VARBINARY" => Ok(DataType::Blob),
170
171        // Numeric without precision defaults to DECIMAL(18,0)
172        "NUMERIC" | "DECIMAL" => Ok(DataType::Decimal {
173            precision: 18,
174            scale: 0,
175        }),
176
177        _ => Err(TypeParseError::UnknownType(upper.to_string())),
178    }
179}
180
/// Return the byte index of the `)` that closes the `(` located at byte index `open_pos`.
///
/// Nesting is tracked from `open_pos` onward; `None` means the paren is never closed.
fn find_matching_paren(s: &str, open_pos: usize) -> Option<usize> {
    let mut nesting: i32 = 0;
    s[open_pos..].char_indices().find_map(|(offset, ch)| {
        if ch == '(' {
            nesting += 1;
        } else if ch == ')' {
            nesting -= 1;
            if nesting == 0 {
                return Some(open_pos + offset);
            }
        }
        None
    })
}
198
/// Split `s` on commas that are not nested inside `(...)` or `[...]`.
///
/// Segments are returned untrimmed, and empty segments are kept (so `""` yields `[""]`).
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    let mut nesting: i32 = 0;
    let mut segment_start = 0usize;
    // All delimiters inspected here are ASCII. ASCII bytes never appear inside a multi-byte
    // UTF-8 sequence, so every offset we slice at is a valid char boundary.
    for (offset, byte) in s.bytes().enumerate() {
        if byte == b'(' || byte == b'[' {
            nesting += 1;
        } else if byte == b')' || byte == b']' {
            nesting -= 1;
        } else if byte == b',' && nesting == 0 {
            segments.push(&s[segment_start..offset]);
            segment_start = offset + 1;
        }
    }
    segments.push(&s[segment_start..]);
    segments
}
218
219/// Parse STRUCT field list: "a INTEGER, b VARCHAR" → vec of (name, DataType)
220fn parse_struct_fields(fields_str: &str) -> Result<DataType, TypeParseError> {
221    let fields_str = fields_str.trim();
222    if fields_str.is_empty() {
223        return Err(TypeParseError::InvalidStruct(
224            "empty field list".to_string(),
225        ));
226    }
227
228    let parts = split_top_level_commas(fields_str);
229    let mut fields = Vec::new();
230
231    for part in parts {
232        let part = part.trim();
233        if part.is_empty() {
234            return Err(TypeParseError::InvalidStruct("empty field".to_string()));
235        }
236
237        // Split into name and type: first whitespace-delimited token is the name,
238        // everything after is the type string
239        let first_space = part.find(|c: char| c.is_whitespace());
240        match first_space {
241            Some(pos) => {
242                let name = part[..pos].trim().to_lowercase();
243                let type_str = part[pos..].trim();
244                let dt = parse_type_inner(type_str)?;
245                fields.push((name, dt));
246            }
247            None => {
248                return Err(TypeParseError::InvalidStruct(format!(
249                    "field '{}' is missing a type",
250                    part
251                )));
252            }
253        }
254    }
255
256    Ok(DataType::Struct(fields))
257}
258
259/// Parse MAP parameters: "KEY_TYPE, VALUE_TYPE"
260fn parse_map_params(params_str: &str) -> Result<DataType, TypeParseError> {
261    let params_str = params_str.trim();
262    if params_str.is_empty() {
263        return Err(TypeParseError::InvalidMap(
264            "empty parameter list".to_string(),
265        ));
266    }
267
268    let parts = split_top_level_commas(params_str);
269    if parts.len() != 2 {
270        return Err(TypeParseError::InvalidMap(format!(
271            "expected 2 type parameters, got {}",
272            parts.len()
273        )));
274    }
275
276    let key_type = parse_type_inner(parts[0].trim())?;
277    let value_type = parse_type_inner(parts[1].trim())?;
278
279    Ok(DataType::Map(Box::new(key_type), Box::new(value_type)))
280}
281
282fn parse_parameterized_type(upper: &str, paren_pos: usize) -> Result<DataType, TypeParseError> {
283    let type_name = upper[..paren_pos].trim();
284    let params_str = &upper[paren_pos + 1..];
285
286    // Find closing paren
287    let close_pos = params_str
288        .find(')')
289        .ok_or(TypeParseError::MissingCloseParen)?;
290    let params = &params_str[..close_pos];
291
292    match type_name {
293        "VARCHAR" | "VARYING" | "CHARACTER VARYING" | "STRING" => {
294            let length = parse_single_number(params, "VARCHAR")?;
295            Ok(DataType::Varchar {
296                max_length: Some(length),
297            })
298        }
299        "CHAR" | "CHARACTER" => {
300            let length = parse_single_number(params, "CHAR")?;
301            Ok(DataType::Char { length })
302        }
303        "DECIMAL" | "NUMERIC" | "DEC" => parse_decimal_params(params),
304        "FLOAT" => {
305            // FLOAT(n) - if n <= 24, use Float; otherwise Double
306            let precision = parse_single_number(params, "FLOAT")?;
307            if precision <= 24 {
308                Ok(DataType::Float)
309            } else {
310                Ok(DataType::Double)
311            }
312        }
313        "TIME" => {
314            // TIME(precision) - we ignore precision for now
315            Ok(DataType::Time)
316        }
317        "TIMESTAMP" => {
318            // TIMESTAMP(precision) - we ignore precision for now
319            // Check for WITH TIME ZONE suffix after the closing paren
320            let suffix = &params_str[close_pos + 1..].trim();
321            let with_tz =
322                suffix.starts_with("WITH TIME ZONE") || suffix.starts_with("WITH TIMEZONE");
323            Ok(DataType::Timestamp {
324                with_timezone: with_tz,
325            })
326        }
327        _ => Err(TypeParseError::UnknownType(type_name.to_string())),
328    }
329}
330
331fn parse_timestamp_type(upper: &str) -> Result<DataType, TypeParseError> {
332    // Handle: TIMESTAMP, TIMESTAMPTZ, TIMESTAMP WITH TIME ZONE, TIMESTAMP WITHOUT TIME ZONE
333    let with_tz = upper.contains("WITH TIME ZONE")
334        || upper.contains("WITH TIMEZONE")
335        || upper == "TIMESTAMPTZ";
336    Ok(DataType::Timestamp {
337        with_timezone: with_tz,
338    })
339}
340
341fn parse_single_number(params: &str, type_name: &str) -> Result<u32, TypeParseError> {
342    params
343        .trim()
344        .parse::<u32>()
345        .map_err(|_| TypeParseError::InvalidLength {
346            type_name: type_name.to_string(),
347            value: params.to_string(),
348        })
349}
350
351fn parse_decimal_params(params: &str) -> Result<DataType, TypeParseError> {
352    let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
353
354    match parts.len() {
355        1 => {
356            // DECIMAL(precision)
357            let precision = parts[0]
358                .parse::<u8>()
359                .map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
360            Ok(DataType::Decimal {
361                precision,
362                scale: 0,
363            })
364        }
365        2 => {
366            // DECIMAL(precision, scale)
367            let precision = parts[0]
368                .parse::<u8>()
369                .map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
370            let scale = parts[1]
371                .parse::<u8>()
372                .map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
373            Ok(DataType::Decimal { precision, scale })
374        }
375        _ => Err(TypeParseError::InvalidDecimal(params.to_string())),
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction helpers that keep the expected values compact. ---

    fn varchar(max_length: Option<u32>) -> DataType {
        DataType::Varchar { max_length }
    }

    fn decimal(precision: u8, scale: u8) -> DataType {
        DataType::Decimal { precision, scale }
    }

    fn timestamp(with_timezone: bool) -> DataType {
        DataType::Timestamp { with_timezone }
    }

    fn array(element: DataType) -> DataType {
        DataType::Array(Box::new(element))
    }

    fn map_of(key: DataType, value: DataType) -> DataType {
        DataType::Map(Box::new(key), Box::new(value))
    }

    fn struct_of(fields: Vec<(&str, DataType)>) -> DataType {
        DataType::Struct(
            fields
                .into_iter()
                .map(|(name, field_type)| (name.to_string(), field_type))
                .collect(),
        )
    }

    /// Assert that every `(input, expected)` pair parses successfully to exactly `expected`.
    fn assert_all_parse(cases: Vec<(&str, DataType)>) {
        for (input, expected) in cases {
            let actual = parse_type(input)
                .unwrap_or_else(|e| panic!("parse_type({input:?}) failed: {e}"));
            assert_eq!(actual, expected, "unexpected result for {input:?}");
        }
    }

    #[test]
    fn test_parse_simple_types() {
        assert_all_parse(vec![
            ("INTEGER", DataType::Integer),
            ("int", DataType::Integer),
            ("BIGINT", DataType::BigInt),
            ("BOOLEAN", DataType::Boolean),
            ("bool", DataType::Boolean),
            ("DATE", DataType::Date),
            ("VARCHAR", varchar(None)),
        ]);
    }

    #[test]
    fn test_parse_varchar_with_length() {
        assert_all_parse(vec![
            ("VARCHAR(255)", varchar(Some(255))),
            ("varchar(100)", varchar(Some(100))),
        ]);
    }

    #[test]
    fn test_parse_char_with_length() {
        assert_all_parse(vec![
            ("CHAR(10)", DataType::Char { length: 10 }),
            ("CHAR", DataType::Char { length: 1 }),
        ]);
    }

    #[test]
    fn test_parse_decimal() {
        assert_all_parse(vec![
            ("DECIMAL(10,2)", decimal(10, 2)),
            ("DECIMAL(18)", decimal(18, 0)),
            ("NUMERIC(5, 3)", decimal(5, 3)),
            // Without parameters
            ("DECIMAL", decimal(18, 0)),
        ]);
    }

    #[test]
    fn test_parse_timestamp() {
        assert_all_parse(vec![
            ("TIMESTAMP", timestamp(false)),
            ("TIMESTAMP WITH TIME ZONE", timestamp(true)),
            ("TIMESTAMPTZ", timestamp(true)),
        ]);
    }

    #[test]
    fn test_parse_float_precision() {
        assert_all_parse(vec![
            ("FLOAT", DataType::Float),
            ("FLOAT(24)", DataType::Float),
            ("FLOAT(53)", DataType::Double),
        ]);
    }

    #[test]
    fn test_parse_aliases() {
        assert_all_parse(vec![
            ("INT", DataType::Integer),
            ("INT4", DataType::Integer),
            ("INT8", DataType::BigInt),
            ("REAL", DataType::Float),
            ("DOUBLE PRECISION", DataType::Double),
            ("TEXT", varchar(None)),
            ("STRING", varchar(None)),
        ]);
    }

    #[test]
    fn test_parse_errors() {
        assert!(matches!(parse_type(""), Err(TypeParseError::EmptyString)));
        assert!(matches!(
            parse_type("FOOBAR"),
            Err(TypeParseError::UnknownType(_))
        ));
        assert!(matches!(
            parse_type("VARCHAR(abc)"),
            Err(TypeParseError::InvalidLength { .. })
        ));
        assert!(matches!(
            parse_type("DECIMAL(a,b)"),
            Err(TypeParseError::InvalidDecimal(_))
        ));
    }

    #[test]
    fn test_case_insensitivity() {
        assert_all_parse(vec![
            ("integer", DataType::Integer),
            ("INTEGER", DataType::Integer),
            ("Integer", DataType::Integer),
            ("varchar(100)", varchar(Some(100))),
        ]);
    }

    #[test]
    fn test_whitespace_handling() {
        assert_all_parse(vec![
            ("  INTEGER  ", DataType::Integer),
            ("DECIMAL( 10 , 2 )", decimal(10, 2)),
        ]);
    }

    // === Complex type parsing tests ===

    #[test]
    fn test_parse_array_bracket_notation() {
        assert_all_parse(vec![
            ("INTEGER[]", array(DataType::Integer)),
            ("VARCHAR[]", array(varchar(None))),
            ("BOOLEAN[]", array(DataType::Boolean)),
        ]);
    }

    #[test]
    fn test_parse_array_suffix_notation() {
        assert_all_parse(vec![
            ("INTEGER ARRAY", array(DataType::Integer)),
            ("VARCHAR ARRAY", array(varchar(None))),
        ]);
    }

    #[test]
    fn test_parse_array_prefix_notation() {
        assert_all_parse(vec![
            ("ARRAY(INTEGER)", array(DataType::Integer)),
            ("ARRAY(VARCHAR)", array(varchar(None))),
        ]);
    }

    #[test]
    fn test_parse_nested_arrays() {
        assert_all_parse(vec![("BIGINT[][]", array(array(DataType::BigInt)))]);
    }

    #[test]
    fn test_parse_struct_simple() {
        assert_all_parse(vec![(
            "STRUCT(a INTEGER, b VARCHAR)",
            struct_of(vec![("a", DataType::Integer), ("b", varchar(None))]),
        )]);
    }

    #[test]
    fn test_parse_struct_aliases() {
        // INT should resolve to INTEGER, BOOL to BOOLEAN
        assert_all_parse(vec![(
            "STRUCT(a INT, b BOOL)",
            struct_of(vec![("a", DataType::Integer), ("b", DataType::Boolean)]),
        )]);
    }

    #[test]
    fn test_parse_struct_with_array_field() {
        assert_all_parse(vec![(
            "STRUCT(a INTEGER[])",
            struct_of(vec![("a", array(DataType::Integer))]),
        )]);
    }

    #[test]
    fn test_parse_nested_struct() {
        assert_all_parse(vec![(
            "STRUCT(a STRUCT(x INTEGER))",
            struct_of(vec![("a", struct_of(vec![("x", DataType::Integer)]))]),
        )]);
    }

    #[test]
    fn test_parse_struct_array() {
        // Array of structs
        assert_all_parse(vec![(
            "STRUCT(a INTEGER, b VARCHAR)[]",
            array(struct_of(vec![
                ("a", DataType::Integer),
                ("b", varchar(None)),
            ])),
        )]);
    }

    #[test]
    fn test_parse_map_simple() {
        assert_all_parse(vec![(
            "MAP(VARCHAR, INTEGER)",
            map_of(varchar(None), DataType::Integer),
        )]);
    }

    #[test]
    fn test_parse_map_with_complex_value() {
        assert_all_parse(vec![(
            "MAP(VARCHAR, STRUCT(a INTEGER))",
            map_of(varchar(None), struct_of(vec![("a", DataType::Integer)])),
        )]);
    }

    #[test]
    fn test_parse_struct_with_map_field() {
        assert_all_parse(vec![(
            "STRUCT(a INTEGER[], b MAP(VARCHAR, INTEGER))",
            struct_of(vec![
                ("a", array(DataType::Integer)),
                ("b", map_of(varchar(None), DataType::Integer)),
            ]),
        )]);
    }

    #[test]
    fn test_parse_deeply_nested() {
        // STRUCT(a STRUCT(x INTEGER, y VARCHAR), b BIGINT)
        assert_all_parse(vec![(
            "STRUCT(a STRUCT(x INTEGER, y VARCHAR), b BIGINT)",
            struct_of(vec![
                (
                    "a",
                    struct_of(vec![("x", DataType::Integer), ("y", varchar(None))]),
                ),
                ("b", DataType::BigInt),
            ]),
        )]);
    }

    #[test]
    fn test_parse_struct_with_decimal_field() {
        // DECIMAL(10,2) has commas inside parens — must not split on them
        assert_all_parse(vec![(
            "STRUCT(a DECIMAL(10,2), b INTEGER)",
            struct_of(vec![("a", decimal(10, 2)), ("b", DataType::Integer)]),
        )]);
    }

    #[test]
    fn test_parse_complex_type_errors() {
        // empty struct, field missing its type, map missing its value type, empty map
        for bad in ["STRUCT()", "STRUCT(a)", "MAP(VARCHAR)", "MAP()"] {
            assert!(parse_type(bad).is_err(), "expected {bad:?} to be rejected");
        }
    }

    #[test]
    fn test_round_trip_all_types() {
        // Verify: DataType → to_sql() → parse_type() → DataType for all canonical types.
        // NOTE: DataType::Text is NOT round-trip safe (Text.to_sql() = "TEXT",
        // parse_type("TEXT") = Varchar). Use normalize() before round-trip testing.
        let types = vec![
            DataType::Boolean,
            DataType::SmallInt,
            DataType::Integer,
            DataType::BigInt,
            DataType::Float,
            DataType::Double,
            decimal(10, 2),
            decimal(18, 0),
            varchar(None),
            varchar(Some(255)),
            DataType::Char { length: 10 },
            DataType::Date,
            DataType::Time,
            timestamp(false),
            timestamp(true),
            DataType::Interval,
            DataType::Blob,
            // Complex types
            array(DataType::Integer),
            array(array(DataType::BigInt)),
            struct_of(vec![("a", DataType::Integer), ("b", varchar(None))]),
            struct_of(vec![(
                "nested",
                struct_of(vec![("x", DataType::BigInt)]),
            )]),
            map_of(varchar(None), DataType::Integer),
            map_of(varchar(None), struct_of(vec![("a", DataType::Integer)])),
            // Array of struct
            array(struct_of(vec![
                ("id", DataType::Integer),
                ("name", varchar(None)),
            ])),
        ];

        for dt in &types {
            let sql = dt.to_sql();
            let parsed = parse_type(&sql).unwrap_or_else(|e| {
                panic!(
                    "Failed to parse to_sql() output '{}' for {:?}: {}",
                    sql, dt, e
                )
            });
            assert_eq!(
                dt, &parsed,
                "Round-trip failed for {:?}: to_sql()='{}', parsed back={:?}",
                dt, sql, parsed
            );
        }
    }

    #[test]
    fn test_round_trip_normalized_text() {
        // Text normalizes to Varchar, which does round-trip
        let normalized = DataType::Text.normalize();
        let reparsed = parse_type(&normalized.to_sql()).unwrap();
        assert_eq!(normalized, reparsed);
    }

    #[test]
    fn test_parse_complex_case_insensitive() {
        assert_all_parse(vec![
            (
                "struct(a integer)",
                struct_of(vec![("a", DataType::Integer)]),
            ),
            (
                "map(varchar, integer)",
                map_of(varchar(None), DataType::Integer),
            ),
        ]);
    }
}