Skip to main content

smelt_types/
lib.rs

1//! Type system definitions for smelt
2//!
3//! This crate provides the core type representations used throughout smelt:
4//! - `DataType`: SQL data types (INTEGER, VARCHAR, DECIMAL, etc.)
5//! - `TypedColumn`: Column with type and nullability
6//!
7//! These types are used by:
8//! - smelt-db for type checking and schema inference
9//! - smelt-cli for source configuration
10//! - smelt-lsp for type-aware editor features
11
12mod functions;
13mod parse;
14
15pub use functions::{FunctionCategory, SqlFunction};
16pub use parse::{parse_type, TypeParseError};
17
/// SQL data types supported by smelt
///
/// This enum represents the logical SQL types. Backend-specific variations
/// (e.g., DuckDB's HUGEINT) are mapped to these canonical types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    // Boolean and numeric types
    // (Boolean is grouped here but is NOT reported by `is_numeric`.)
    /// Boolean (TRUE/FALSE)
    Boolean,
    /// Small integer (2 bytes, -32768 to 32767)
    SmallInt,
    /// Integer (4 bytes)
    Integer,
    /// Big integer (8 bytes)
    BigInt,
    /// Single-precision floating point
    Float,
    /// Double-precision floating point
    Double,
    /// Exact decimal with precision and scale
    Decimal { precision: u8, scale: u8 },

    // String types
    /// Variable-length string with optional max length
    Varchar { max_length: Option<u32> },
    /// Fixed-length string
    Char { length: u32 },
    /// Unbounded text
    Text,

    // Binary types
    /// Binary large object
    Blob,

    // Date/Time types
    /// Calendar date (year, month, day)
    Date,
    /// Time of day
    Time,
    /// Timestamp (date + time)
    Timestamp { with_timezone: bool },
    /// Time interval
    Interval,

    // Complex types
    /// Array of elements
    Array(Box<DataType>),
    /// Struct with named fields: STRUCT(a INTEGER, b VARCHAR)
    Struct(Vec<(String, DataType)>),
    /// Map from key type to value type: MAP(VARCHAR, INTEGER)
    Map(Box<DataType>, Box<DataType>),

    // Special types
    /// NULL literal type
    Null,
    /// Type could not be determined (error recovery)
    Unknown,
}

impl DataType {
    /// Check if this type is numeric (supports arithmetic operations).
    ///
    /// Covers the integer, floating-point and decimal variants. `Boolean`
    /// is not considered numeric.
    pub fn is_numeric(&self) -> bool {
        matches!(
            self,
            DataType::SmallInt
                | DataType::Integer
                | DataType::BigInt
                | DataType::Float
                | DataType::Double
                | DataType::Decimal { .. }
        )
    }

    /// Check if this type is a string type (`Varchar`, `Char` or `Text`).
    pub fn is_string(&self) -> bool {
        matches!(
            self,
            DataType::Varchar { .. } | DataType::Char { .. } | DataType::Text
        )
    }

    /// Check if this type is a complex/nested type (Array, Struct, Map)
    pub fn is_complex(&self) -> bool {
        matches!(
            self,
            DataType::Array(_) | DataType::Struct(_) | DataType::Map(_, _)
        )
    }

    /// Check if this type is a date/time type
    /// (`Date`, `Time`, `Timestamp` or `Interval`).
    pub fn is_temporal(&self) -> bool {
        matches!(
            self,
            DataType::Date | DataType::Time | DataType::Timestamp { .. } | DataType::Interval
        )
    }

    /// Normalize this type to its canonical form for comparison.
    ///
    /// - `Text` → `Varchar { max_length: None }` (canonical string type)
    /// - Recursively normalizes Array elements, Struct fields, Map key/value
    /// - All other types are returned as-is
    pub fn normalize(&self) -> DataType {
        match self {
            DataType::Text => DataType::Varchar { max_length: None },
            DataType::Array(inner) => DataType::Array(Box::new(inner.normalize())),
            DataType::Struct(fields) => DataType::Struct(
                fields
                    .iter()
                    .map(|(name, dt)| (name.clone(), dt.normalize()))
                    .collect(),
            ),
            DataType::Map(k, v) => DataType::Map(Box::new(k.normalize()), Box::new(v.normalize())),
            other => other.clone(),
        }
    }

    /// Format as SQL type string for backend compilation.
    ///
    /// Translates smelt-internal types to what backends actually support:
    /// - `Text` → `"VARCHAR"` (backends don't distinguish Text from VARCHAR)
    ///
    /// The translation is applied at every nesting level, so a `Text` inside
    /// an Array, Struct or Map is rendered as `VARCHAR` as well
    /// (e.g. `Array(Text)` → `"VARCHAR[]"`). Every other type renders exactly
    /// as in [`DataType::to_sql`].
    pub fn to_backend_sql(&self) -> String {
        self.render_sql(true)
    }

    /// Format as SQL type string for the default dialect.
    ///
    /// `Text` is rendered as `"TEXT"`; nested types render their element,
    /// field and key/value types recursively. A `Decimal` with scale `0` is
    /// rendered with precision only (`DECIMAL(p)`).
    pub fn to_sql(&self) -> String {
        self.render_sql(false)
    }

    /// Shared renderer behind `to_sql` and `to_backend_sql`.
    ///
    /// `text_as_varchar` selects the backend spelling of `Text`; it is passed
    /// down unchanged through Array/Struct/Map so nested types agree with the
    /// top-level choice.
    fn render_sql(&self, text_as_varchar: bool) -> String {
        match self {
            DataType::Boolean => "BOOLEAN".to_string(),
            DataType::SmallInt => "SMALLINT".to_string(),
            DataType::Integer => "INTEGER".to_string(),
            DataType::BigInt => "BIGINT".to_string(),
            DataType::Float => "FLOAT".to_string(),
            DataType::Double => "DOUBLE".to_string(),
            // Scale 0 is omitted from the rendered type.
            DataType::Decimal {
                precision,
                scale: 0,
            } => format!("DECIMAL({precision})"),
            DataType::Decimal { precision, scale } => format!("DECIMAL({precision},{scale})"),
            DataType::Varchar { max_length: None } => "VARCHAR".to_string(),
            DataType::Varchar {
                max_length: Some(len),
            } => format!("VARCHAR({len})"),
            DataType::Char { length } => format!("CHAR({length})"),
            DataType::Text if text_as_varchar => "VARCHAR".to_string(),
            DataType::Text => "TEXT".to_string(),
            DataType::Blob => "BLOB".to_string(),
            DataType::Date => "DATE".to_string(),
            DataType::Time => "TIME".to_string(),
            DataType::Timestamp {
                with_timezone: true,
            } => "TIMESTAMP WITH TIME ZONE".to_string(),
            DataType::Timestamp {
                with_timezone: false,
            } => "TIMESTAMP".to_string(),
            DataType::Interval => "INTERVAL".to_string(),
            DataType::Array(inner) => format!("{}[]", inner.render_sql(text_as_varchar)),
            DataType::Struct(fields) => {
                // Field names are emitted verbatim (no quoting or escaping).
                let rendered: Vec<String> = fields
                    .iter()
                    .map(|(name, dt)| format!("{} {}", name, dt.render_sql(text_as_varchar)))
                    .collect();
                format!("STRUCT({})", rendered.join(", "))
            }
            DataType::Map(key, value) => format!(
                "MAP({}, {})",
                key.render_sql(text_as_varchar),
                value.render_sql(text_as_varchar)
            ),
            DataType::Null => "NULL".to_string(),
            DataType::Unknown => "UNKNOWN".to_string(),
        }
    }
}

/// Displays the type using its default-dialect SQL spelling (`to_sql`).
impl std::fmt::Display for DataType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.to_sql())
    }
}
201
/// A column with its data type and nullability
///
/// The `Display` implementation renders the data type's SQL string,
/// followed by ` NOT NULL` when `nullable` is `false`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedColumn {
    /// The SQL data type
    pub data_type: DataType,
    /// Whether the column can contain NULL values
    pub nullable: bool,
}
210
211impl TypedColumn {
212    /// Create a new typed column
213    pub fn new(data_type: DataType, nullable: bool) -> Self {
214        Self {
215            data_type,
216            nullable,
217        }
218    }
219
220    /// Create a nullable column
221    pub fn nullable(data_type: DataType) -> Self {
222        Self::new(data_type, true)
223    }
224
225    /// Create a non-nullable column
226    pub fn not_null(data_type: DataType) -> Self {
227        Self::new(data_type, false)
228    }
229
230    /// Create an unknown type (for error recovery)
231    pub fn unknown() -> Self {
232        Self::nullable(DataType::Unknown)
233    }
234}
235
236impl std::fmt::Display for TypedColumn {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        write!(f, "{}", self.data_type)?;
239        if !self.nullable {
240            write!(f, " NOT NULL")?;
241        }
242        Ok(())
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    // --- fixtures shared by several tests ---

    /// Unbounded `VARCHAR`, the canonical string type.
    fn varchar() -> DataType {
        DataType::Varchar { max_length: None }
    }

    /// `DECIMAL(10,2)`.
    fn decimal_10_2() -> DataType {
        DataType::Decimal {
            precision: 10,
            scale: 2,
        }
    }

    /// Array whose element type is `elem`.
    fn array_of(elem: DataType) -> DataType {
        DataType::Array(Box::new(elem))
    }

    /// Map from `key` to `value`.
    fn map_of(key: DataType, value: DataType) -> DataType {
        DataType::Map(Box::new(key), Box::new(value))
    }

    /// Struct with a single field `name` of type `ty`.
    fn struct_of(name: &str, ty: DataType) -> DataType {
        DataType::Struct(vec![(name.to_string(), ty)])
    }

    #[test]
    fn test_data_type_display() {
        let cases = [
            (DataType::Integer, "INTEGER"),
            (decimal_10_2(), "DECIMAL(10,2)"),
            (varchar(), "VARCHAR"),
            (
                DataType::Varchar {
                    max_length: Some(255),
                },
                "VARCHAR(255)",
            ),
            (
                DataType::Timestamp {
                    with_timezone: true,
                },
                "TIMESTAMP WITH TIME ZONE",
            ),
            (array_of(DataType::Integer), "INTEGER[]"),
        ];
        for (ty, expected) in cases {
            assert_eq!(ty.to_string(), expected);
        }
    }

    #[test]
    fn test_to_backend_sql_text_becomes_varchar() {
        let cases = [
            (DataType::Text, "VARCHAR"),
            (DataType::Integer, "INTEGER"),
            (varchar(), "VARCHAR"),
        ];
        for (ty, expected) in cases {
            assert_eq!(ty.to_backend_sql(), expected);
        }
    }

    #[test]
    fn test_is_numeric() {
        let numeric = [
            DataType::Integer,
            DataType::BigInt,
            DataType::Double,
            decimal_10_2(),
        ];
        for ty in numeric {
            assert!(ty.is_numeric());
        }

        let non_numeric = [varchar(), DataType::Date];
        for ty in non_numeric {
            assert!(!ty.is_numeric());
        }
    }

    #[test]
    fn test_is_complex() {
        let complex = [
            array_of(DataType::Integer),
            struct_of("a", DataType::Integer),
            map_of(varchar(), DataType::Integer),
        ];
        for ty in complex {
            assert!(ty.is_complex());
        }

        let scalar = [DataType::Integer, varchar(), DataType::Boolean];
        for ty in scalar {
            assert!(!ty.is_complex());
        }
    }

    #[test]
    fn test_map_to_sql() {
        let map = map_of(varchar(), DataType::Integer);
        assert_eq!(map.to_sql(), "MAP(VARCHAR, INTEGER)");
    }

    // === normalize() tests ===

    #[test]
    fn test_normalize_text_to_varchar() {
        assert_eq!(DataType::Text.normalize(), varchar());
    }

    #[test]
    fn test_normalize_scalar_unchanged() {
        // Each of these is already canonical, so normalize() must be the identity.
        let already_canonical = [
            DataType::Integer,
            DataType::BigInt,
            DataType::Boolean,
            varchar(),
            decimal_10_2(),
        ];
        for ty in already_canonical {
            assert_eq!(ty.normalize(), ty);
        }
    }

    #[test]
    fn test_normalize_array_recursive() {
        // Array(Text) → Array(Varchar)
        assert_eq!(array_of(DataType::Text).normalize(), array_of(varchar()));

        // Array(Integer) unchanged
        assert_eq!(
            array_of(DataType::Integer).normalize(),
            array_of(DataType::Integer)
        );
    }

    #[test]
    fn test_normalize_struct_recursive() {
        let input = DataType::Struct(vec![
            ("a".to_string(), DataType::Text),
            ("b".to_string(), DataType::Integer),
        ]);
        let expected = DataType::Struct(vec![
            ("a".to_string(), varchar()),
            ("b".to_string(), DataType::Integer),
        ]);
        assert_eq!(input.normalize(), expected);
    }

    #[test]
    fn test_normalize_map_recursive() {
        let input = map_of(DataType::Text, DataType::Text);
        assert_eq!(input.normalize(), map_of(varchar(), varchar()));
    }

    #[test]
    fn test_normalize_deeply_nested() {
        // STRUCT(a STRUCT(x Text)) → STRUCT(a STRUCT(x Varchar))
        let input = struct_of("a", struct_of("x", DataType::Text));
        let expected = struct_of("a", struct_of("x", varchar()));
        assert_eq!(input.normalize(), expected);
    }

    #[test]
    fn test_typed_column_display() {
        let required = TypedColumn::not_null(DataType::Integer);
        assert_eq!(required.to_string(), "INTEGER NOT NULL");

        let optional = TypedColumn::nullable(DataType::Varchar {
            max_length: Some(100),
        });
        assert_eq!(optional.to_string(), "VARCHAR(100)");
    }
}