vibesql/ast/
types.rs

1//! Data type specifications for the SQL AST.
2//!
3//! This module defines how data types are represented in the AST,
4//! separate from the runtime type system. Type names follow ISO SQL standards.
5
6use super::Ident;
7use crate::error::Span;
8
9/// Data type specification as it appears in SQL.
10#[derive(Debug, Clone, PartialEq)]
11pub struct DataTypeSpec {
12    pub kind: DataTypeKind,
13    pub span: Span,
14}
15
16impl DataTypeSpec {
17    pub fn new(kind: DataTypeKind, span: Span) -> Self {
18        Self { kind, span }
19    }
20}
21
22/// The kind of data type.
23///
24/// These are the canonical types used in the AST. The parser handles
25/// all standard SQL type aliases (INTEGER, BIGINT, VARCHAR, etc.) and
26/// normalizes them to these canonical forms.
27#[derive(Debug, Clone, PartialEq)]
28pub enum DataTypeKind {
29    /// Boolean type (BOOLEAN, BOOL)
30    Bool,
31
32    /// 32-bit signed integer (INTEGER, INT, INT32)
33    Int32,
34
35    /// 64-bit signed integer (BIGINT, INT64)
36    Int64,
37
38    /// 32-bit unsigned integer (UINTEGER, UINT32)
39    Uint32,
40
41    /// 64-bit unsigned integer (UBIGINT, UINT64)
42    Uint64,
43
44    /// 32-bit floating point (REAL, FLOAT32)
45    Float32,
46
47    /// 64-bit floating point (DOUBLE PRECISION, DOUBLE, FLOAT, FLOAT64)
48    Float64,
49
50    /// Fixed precision decimal (NUMERIC, DECIMAL)
51    Numeric {
52        precision: Option<u8>,
53        scale: Option<u8>,
54    },
55
56    /// Variable-length character string (VARCHAR, TEXT, STRING, CHAR)
57    Varchar { max_length: Option<u64> },
58
59    /// Variable-length binary data (VARBINARY, BYTEA, BYTES, BLOB)
60    Varbinary { max_length: Option<u64> },
61
62    /// Date (year, month, day)
63    Date,
64
65    /// Time of day
66    Time,
67
68    /// Date and time without timezone (DATETIME, TIMESTAMP WITHOUT TIME ZONE)
69    Datetime,
70
71    /// Date and time with timezone (TIMESTAMP, TIMESTAMP WITH TIME ZONE)
72    Timestamp,
73
74    /// Time interval
75    Interval,
76
77    /// Array of elements
78    Array(Box<DataTypeSpec>),
79
80    /// Struct with named fields (ROW type in standard SQL)
81    Struct(Vec<StructField>),
82
83    /// JSON data
84    Json,
85
86    /// Range of values
87    Range(Box<DataTypeSpec>),
88
89    /// UUID type
90    Uuid,
91
92    /// Named types (for user-defined types)
93    Named(Vec<Ident>),
94}
95
96/// Struct field in a STRUCT type.
97#[derive(Debug, Clone, PartialEq)]
98pub struct StructField {
99    pub name: Option<Ident>,
100    pub data_type: DataTypeSpec,
101}
102
103impl DataTypeKind {
104    /// Check if this type is a numeric type.
105    pub fn is_numeric(&self) -> bool {
106        matches!(
107            self,
108            DataTypeKind::Int32
109                | DataTypeKind::Int64
110                | DataTypeKind::Uint32
111                | DataTypeKind::Uint64
112                | DataTypeKind::Float32
113                | DataTypeKind::Float64
114                | DataTypeKind::Numeric { .. }
115        )
116    }
117
118    /// Check if this type is an integer type.
119    pub fn is_integer(&self) -> bool {
120        matches!(
121            self,
122            DataTypeKind::Int32 | DataTypeKind::Int64 | DataTypeKind::Uint32 | DataTypeKind::Uint64
123        )
124    }
125
126    /// Check if this type is a floating-point type.
127    pub fn is_floating_point(&self) -> bool {
128        matches!(self, DataTypeKind::Float32 | DataTypeKind::Float64)
129    }
130
131    /// Check if this type is a date/time type.
132    pub fn is_datetime(&self) -> bool {
133        matches!(
134            self,
135            DataTypeKind::Date
136                | DataTypeKind::Time
137                | DataTypeKind::Datetime
138                | DataTypeKind::Timestamp
139        )
140    }
141
142    /// Check if this type is a string type.
143    pub fn is_string(&self) -> bool {
144        matches!(self, DataTypeKind::Varchar { .. })
145    }
146}
147
148impl std::fmt::Display for DataTypeKind {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        match self {
151            DataTypeKind::Bool => write!(f, "BOOLEAN"),
152            DataTypeKind::Int32 => write!(f, "INTEGER"),
153            DataTypeKind::Int64 => write!(f, "BIGINT"),
154            DataTypeKind::Uint32 => write!(f, "UINTEGER"),
155            DataTypeKind::Uint64 => write!(f, "UBIGINT"),
156            DataTypeKind::Float32 => write!(f, "REAL"),
157            DataTypeKind::Float64 => write!(f, "DOUBLE PRECISION"),
158            DataTypeKind::Numeric { precision, scale } => {
159                write!(f, "NUMERIC")?;
160                if let Some(p) = precision {
161                    write!(f, "({}", p)?;
162                    if let Some(s) = scale {
163                        write!(f, ", {}", s)?;
164                    }
165                    write!(f, ")")?;
166                }
167                Ok(())
168            }
169            DataTypeKind::Varchar { max_length } => {
170                write!(f, "VARCHAR")?;
171                if let Some(len) = max_length {
172                    write!(f, "({})", len)?;
173                }
174                Ok(())
175            }
176            DataTypeKind::Varbinary { max_length } => {
177                write!(f, "VARBINARY")?;
178                if let Some(len) = max_length {
179                    write!(f, "({})", len)?;
180                }
181                Ok(())
182            }
183            DataTypeKind::Date => write!(f, "DATE"),
184            DataTypeKind::Time => write!(f, "TIME"),
185            DataTypeKind::Datetime => write!(f, "DATETIME"),
186            DataTypeKind::Timestamp => write!(f, "TIMESTAMP"),
187            DataTypeKind::Interval => write!(f, "INTERVAL"),
188            DataTypeKind::Array(elem) => write!(f, "ARRAY<{}>", elem.kind),
189            DataTypeKind::Struct(fields) => {
190                write!(f, "STRUCT<")?;
191                for (i, field) in fields.iter().enumerate() {
192                    if i > 0 {
193                        write!(f, ", ")?;
194                    }
195                    if let Some(name) = &field.name {
196                        write!(f, "{} ", name)?;
197                    }
198                    write!(f, "{}", field.data_type.kind)?;
199                }
200                write!(f, ">")
201            }
202            DataTypeKind::Json => write!(f, "JSON"),
203            DataTypeKind::Range(elem) => write!(f, "RANGE<{}>", elem.kind),
204            DataTypeKind::Uuid => write!(f, "UUID"),
205            DataTypeKind::Named(parts) => {
206                for (i, part) in parts.iter().enumerate() {
207                    if i > 0 {
208                        write!(f, ".")?;
209                    }
210                    write!(f, "{}", part)?;
211                }
212                Ok(())
213            }
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the text produced by `Display` for several kinds.
    #[test]
    fn test_type_display() {
        // Kinds without parameters.
        assert_eq!(DataTypeKind::Int64.to_string(), "BIGINT");
        assert_eq!(DataTypeKind::Int32.to_string(), "INTEGER");
        assert_eq!(DataTypeKind::Float64.to_string(), "DOUBLE PRECISION");
        assert_eq!(DataTypeKind::Float32.to_string(), "REAL");
        assert_eq!(DataTypeKind::Bool.to_string(), "BOOLEAN");

        // Kinds that carry parameters.
        let numeric = DataTypeKind::Numeric {
            precision: Some(10),
            scale: Some(2),
        };
        assert_eq!(numeric.to_string(), "NUMERIC(10, 2)");

        let varchar = DataTypeKind::Varchar {
            max_length: Some(255),
        };
        assert_eq!(varchar.to_string(), "VARCHAR(255)");

        let varbinary = DataTypeKind::Varbinary {
            max_length: Some(100),
        };
        assert_eq!(varbinary.to_string(), "VARBINARY(100)");
    }

    /// Checks the classification predicates on representative kinds.
    #[test]
    fn test_type_classification() {
        let bigint = DataTypeKind::Int64;
        assert!(bigint.is_numeric());
        assert!(bigint.is_integer());
        assert!(!bigint.is_floating_point());

        let double = DataTypeKind::Float64;
        assert!(double.is_numeric());
        assert!(!double.is_integer());
        assert!(double.is_floating_point());

        assert!(DataTypeKind::Date.is_datetime());
        assert!(DataTypeKind::Varchar { max_length: None }.is_string());
    }
}