1mod functions;
13mod parse;
14
15pub use functions::{FunctionCategory, SqlFunction};
16pub use parse::{parse_type, TypeParseError};
17
/// A SQL data type.
///
/// Each variant documents the SQL spelling produced for it by [`DataType::to_sql`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// `BOOLEAN`.
    Boolean,
    /// `SMALLINT`.
    SmallInt,
    /// `INTEGER`.
    Integer,
    /// `BIGINT`.
    BigInt,
    /// `FLOAT`.
    Float,
    /// `DOUBLE`.
    Double,
    /// `DECIMAL(precision,scale)`; rendered as `DECIMAL(precision)` when `scale` is 0.
    Decimal { precision: u8, scale: u8 },

    /// `VARCHAR`, or `VARCHAR(n)` when `max_length` is `Some(n)`.
    Varchar { max_length: Option<u32> },
    /// `CHAR(n)`.
    Char { length: u32 },
    /// `TEXT`. [`DataType::normalize`] rewrites this to `Varchar { max_length: None }`.
    Text,

    /// `BLOB`.
    Blob,

    /// `DATE`.
    Date,
    /// `TIME`.
    Time,
    /// `TIMESTAMP`, or `TIMESTAMP WITH TIME ZONE` when `with_timezone` is true.
    Timestamp { with_timezone: bool },
    /// `INTERVAL`.
    Interval,

    /// Array of the boxed element type, rendered as `ELEM[]`.
    Array(Box<DataType>),
    /// Ordered `(field name, field type)` pairs, rendered as `STRUCT(name TYPE, ...)`.
    Struct(Vec<(String, DataType)>),
    /// Key type and value type, rendered as `MAP(KEY, VALUE)`.
    Map(Box<DataType>, Box<DataType>),

    /// `NULL` — presumably the type of an untyped NULL literal (TODO confirm).
    Null,
    /// `UNKNOWN` — placeholder, presumably for a type that has not been resolved.
    Unknown,
}

impl DataType {
    /// Returns `true` for `SmallInt`, `Integer`, `BigInt`, `Float`, `Double` and `Decimal`.
    pub fn is_numeric(&self) -> bool {
        matches!(
            self,
            DataType::SmallInt
                | DataType::Integer
                | DataType::BigInt
                | DataType::Float
                | DataType::Double
                | DataType::Decimal { .. }
        )
    }

    /// Returns `true` for `Varchar`, `Char` and `Text`.
    pub fn is_string(&self) -> bool {
        matches!(
            self,
            DataType::Varchar { .. } | DataType::Char { .. } | DataType::Text
        )
    }

    /// Returns `true` for the nested types `Array`, `Struct` and `Map`.
    pub fn is_complex(&self) -> bool {
        matches!(
            self,
            DataType::Array(_) | DataType::Struct(_) | DataType::Map(_, _)
        )
    }

    /// Returns `true` for `Date`, `Time`, `Timestamp` and `Interval`.
    pub fn is_temporal(&self) -> bool {
        matches!(
            self,
            DataType::Date | DataType::Time | DataType::Timestamp { .. } | DataType::Interval
        )
    }

    /// Returns a canonical copy of this type.
    ///
    /// `Text` becomes `Varchar { max_length: None }`. The rewrite is applied recursively to
    /// array elements, struct fields (names preserved, order preserved), and map keys and
    /// values. Every other type is returned unchanged.
    pub fn normalize(&self) -> DataType {
        match self {
            DataType::Text => DataType::Varchar { max_length: None },
            DataType::Array(inner) => DataType::Array(Box::new(inner.normalize())),
            DataType::Struct(fields) => DataType::Struct(
                fields
                    .iter()
                    .map(|(name, dt)| (name.clone(), dt.normalize()))
                    .collect(),
            ),
            DataType::Map(k, v) => DataType::Map(Box::new(k.normalize()), Box::new(v.normalize())),
            // Leaf types are listed explicitly (no catch-all) so that adding a new nested
            // variant is a compile error here instead of silently skipping recursion.
            DataType::Boolean
            | DataType::SmallInt
            | DataType::Integer
            | DataType::BigInt
            | DataType::Float
            | DataType::Double
            | DataType::Decimal { .. }
            | DataType::Varchar { .. }
            | DataType::Char { .. }
            | DataType::Blob
            | DataType::Date
            | DataType::Time
            | DataType::Timestamp { .. }
            | DataType::Interval
            | DataType::Null
            | DataType::Unknown => self.clone(),
        }
    }

    /// Renders this type as SQL for the backend.
    ///
    /// This is [`DataType::to_sql`] applied to the [normalized](DataType::normalize) type, so
    /// `TEXT` is emitted as `VARCHAR` at every nesting level (`Array(Text)` renders as
    /// `VARCHAR[]`, not `TEXT[]`).
    pub fn to_backend_sql(&self) -> String {
        // Going through `normalize` keeps the backend spelling in lockstep with the
        // canonical form, including inside arrays, structs and maps.
        self.normalize().to_sql()
    }

    /// Renders this type using its SQL spelling, preserving `TEXT` as written.
    pub fn to_sql(&self) -> String {
        match self {
            DataType::Boolean => "BOOLEAN".to_string(),
            DataType::SmallInt => "SMALLINT".to_string(),
            DataType::Integer => "INTEGER".to_string(),
            DataType::BigInt => "BIGINT".to_string(),
            DataType::Float => "FLOAT".to_string(),
            DataType::Double => "DOUBLE".to_string(),
            DataType::Decimal { precision, scale } => {
                // A zero scale is omitted: `DECIMAL(p)` rather than `DECIMAL(p,0)`.
                if *scale == 0 {
                    format!("DECIMAL({precision})")
                } else {
                    format!("DECIMAL({precision},{scale})")
                }
            }
            DataType::Varchar { max_length: None } => "VARCHAR".to_string(),
            DataType::Varchar {
                max_length: Some(len),
            } => format!("VARCHAR({len})"),
            DataType::Char { length } => format!("CHAR({length})"),
            DataType::Text => "TEXT".to_string(),
            DataType::Blob => "BLOB".to_string(),
            DataType::Date => "DATE".to_string(),
            DataType::Time => "TIME".to_string(),
            DataType::Timestamp { with_timezone } => {
                if *with_timezone {
                    "TIMESTAMP WITH TIME ZONE".to_string()
                } else {
                    "TIMESTAMP".to_string()
                }
            }
            DataType::Interval => "INTERVAL".to_string(),
            DataType::Array(inner) => format!("{}[]", inner.to_sql()),
            DataType::Struct(fields) => {
                // Field names are emitted verbatim (not quoted).
                let field_strs: Vec<String> = fields
                    .iter()
                    .map(|(name, dt)| format!("{} {}", name, dt.to_sql()))
                    .collect();
                format!("STRUCT({})", field_strs.join(", "))
            }
            DataType::Map(key, value) => {
                format!("MAP({}, {})", key.to_sql(), value.to_sql())
            }
            DataType::Null => "NULL".to_string(),
            DataType::Unknown => "UNKNOWN".to_string(),
        }
    }
}
195
196impl std::fmt::Display for DataType {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 write!(f, "{}", self.to_sql())
199 }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct TypedColumn {
205 pub data_type: DataType,
207 pub nullable: bool,
209}
210
211impl TypedColumn {
212 pub fn new(data_type: DataType, nullable: bool) -> Self {
214 Self {
215 data_type,
216 nullable,
217 }
218 }
219
220 pub fn nullable(data_type: DataType) -> Self {
222 Self::new(data_type, true)
223 }
224
225 pub fn not_null(data_type: DataType) -> Self {
227 Self::new(data_type, false)
228 }
229
230 pub fn unknown() -> Self {
232 Self::nullable(DataType::Unknown)
233 }
234}
235
236impl std::fmt::Display for TypedColumn {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 write!(f, "{}", self.data_type)?;
239 if !self.nullable {
240 write!(f, " NOT NULL")?;
241 }
242 Ok(())
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_data_type_display() {
        assert_eq!(DataType::Integer.to_string(), "INTEGER");
        assert_eq!(
            DataType::Decimal {
                precision: 10,
                scale: 2
            }
            .to_string(),
            "DECIMAL(10,2)"
        );
        assert_eq!(
            DataType::Varchar { max_length: None }.to_string(),
            "VARCHAR"
        );
        assert_eq!(
            DataType::Varchar {
                max_length: Some(255)
            }
            .to_string(),
            "VARCHAR(255)"
        );
        assert_eq!(
            DataType::Timestamp {
                with_timezone: true
            }
            .to_string(),
            "TIMESTAMP WITH TIME ZONE"
        );
        assert_eq!(
            DataType::Array(Box::new(DataType::Integer)).to_string(),
            "INTEGER[]"
        );
    }

    #[test]
    fn test_scalar_to_sql_variants() {
        // A zero scale is omitted from the DECIMAL spelling.
        assert_eq!(
            DataType::Decimal {
                precision: 18,
                scale: 0
            }
            .to_sql(),
            "DECIMAL(18)"
        );
        assert_eq!(DataType::Char { length: 4 }.to_sql(), "CHAR(4)");
        assert_eq!(
            DataType::Timestamp {
                with_timezone: false
            }
            .to_sql(),
            "TIMESTAMP"
        );
        assert_eq!(DataType::Text.to_sql(), "TEXT");
        assert_eq!(DataType::Blob.to_sql(), "BLOB");
        assert_eq!(DataType::Null.to_sql(), "NULL");
        assert_eq!(DataType::Unknown.to_sql(), "UNKNOWN");
    }

    #[test]
    fn test_struct_to_sql() {
        let s = DataType::Struct(vec![
            ("id".to_string(), DataType::BigInt),
            (
                "tags".to_string(),
                DataType::Array(Box::new(DataType::Varchar { max_length: None })),
            ),
        ]);
        assert_eq!(s.to_sql(), "STRUCT(id BIGINT, tags VARCHAR[])");
    }

    #[test]
    fn test_to_backend_sql_text_becomes_varchar() {
        assert_eq!(DataType::Text.to_backend_sql(), "VARCHAR");
        assert_eq!(DataType::Integer.to_backend_sql(), "INTEGER");
        assert_eq!(
            DataType::Varchar { max_length: None }.to_backend_sql(),
            "VARCHAR"
        );
        // Nested types without TEXT render the same as `to_sql`.
        assert_eq!(
            DataType::Array(Box::new(DataType::Integer)).to_backend_sql(),
            "INTEGER[]"
        );
    }

    #[test]
    fn test_is_numeric() {
        assert!(DataType::Integer.is_numeric());
        assert!(DataType::BigInt.is_numeric());
        assert!(DataType::Double.is_numeric());
        assert!(DataType::Decimal {
            precision: 10,
            scale: 2
        }
        .is_numeric());
        assert!(!DataType::Varchar { max_length: None }.is_numeric());
        assert!(!DataType::Date.is_numeric());
        assert!(!DataType::Null.is_numeric());
    }

    #[test]
    fn test_is_string() {
        assert!(DataType::Text.is_string());
        assert!(DataType::Char { length: 3 }.is_string());
        assert!(DataType::Varchar {
            max_length: Some(10)
        }
        .is_string());
        assert!(!DataType::Blob.is_string());
        assert!(!DataType::Integer.is_string());
    }

    #[test]
    fn test_is_temporal() {
        assert!(DataType::Date.is_temporal());
        assert!(DataType::Time.is_temporal());
        assert!(DataType::Timestamp {
            with_timezone: false
        }
        .is_temporal());
        assert!(DataType::Interval.is_temporal());
        assert!(!DataType::BigInt.is_temporal());
        assert!(!DataType::Text.is_temporal());
    }

    #[test]
    fn test_is_complex() {
        assert!(DataType::Array(Box::new(DataType::Integer)).is_complex());
        assert!(DataType::Struct(vec![("a".to_string(), DataType::Integer)]).is_complex());
        assert!(DataType::Map(
            Box::new(DataType::Varchar { max_length: None }),
            Box::new(DataType::Integer)
        )
        .is_complex());
        assert!(!DataType::Integer.is_complex());
        assert!(!DataType::Varchar { max_length: None }.is_complex());
        assert!(!DataType::Boolean.is_complex());
    }

    #[test]
    fn test_map_to_sql() {
        assert_eq!(
            DataType::Map(
                Box::new(DataType::Varchar { max_length: None }),
                Box::new(DataType::Integer)
            )
            .to_sql(),
            "MAP(VARCHAR, INTEGER)"
        );
    }

    #[test]
    fn test_normalize_text_to_varchar() {
        assert_eq!(
            DataType::Text.normalize(),
            DataType::Varchar { max_length: None }
        );
    }

    #[test]
    fn test_normalize_scalar_unchanged() {
        assert_eq!(DataType::Integer.normalize(), DataType::Integer);
        assert_eq!(DataType::BigInt.normalize(), DataType::BigInt);
        assert_eq!(DataType::Boolean.normalize(), DataType::Boolean);
        assert_eq!(
            DataType::Varchar { max_length: None }.normalize(),
            DataType::Varchar { max_length: None }
        );
        assert_eq!(
            DataType::Decimal {
                precision: 10,
                scale: 2
            }
            .normalize(),
            DataType::Decimal {
                precision: 10,
                scale: 2
            }
        );
    }

    #[test]
    fn test_normalize_array_recursive() {
        let arr = DataType::Array(Box::new(DataType::Text));
        assert_eq!(
            arr.normalize(),
            DataType::Array(Box::new(DataType::Varchar { max_length: None }))
        );

        let arr = DataType::Array(Box::new(DataType::Integer));
        assert_eq!(
            arr.normalize(),
            DataType::Array(Box::new(DataType::Integer))
        );
    }

    #[test]
    fn test_normalize_struct_recursive() {
        let s = DataType::Struct(vec![
            ("a".to_string(), DataType::Text),
            ("b".to_string(), DataType::Integer),
        ]);
        assert_eq!(
            s.normalize(),
            DataType::Struct(vec![
                ("a".to_string(), DataType::Varchar { max_length: None }),
                ("b".to_string(), DataType::Integer),
            ])
        );
    }

    #[test]
    fn test_normalize_map_recursive() {
        let m = DataType::Map(Box::new(DataType::Text), Box::new(DataType::Text));
        assert_eq!(
            m.normalize(),
            DataType::Map(
                Box::new(DataType::Varchar { max_length: None }),
                Box::new(DataType::Varchar { max_length: None })
            )
        );
    }

    #[test]
    fn test_normalize_deeply_nested() {
        let s = DataType::Struct(vec![(
            "a".to_string(),
            DataType::Struct(vec![("x".to_string(), DataType::Text)]),
        )]);
        assert_eq!(
            s.normalize(),
            DataType::Struct(vec![(
                "a".to_string(),
                DataType::Struct(vec![(
                    "x".to_string(),
                    DataType::Varchar { max_length: None }
                )]),
            )])
        );
    }

    #[test]
    fn test_typed_column_display() {
        let col = TypedColumn::not_null(DataType::Integer);
        assert_eq!(col.to_string(), "INTEGER NOT NULL");

        let col = TypedColumn::nullable(DataType::Varchar {
            max_length: Some(100),
        });
        assert_eq!(col.to_string(), "VARCHAR(100)");
    }

    #[test]
    fn test_typed_column_unknown() {
        let col = TypedColumn::unknown();
        assert!(col.nullable);
        assert_eq!(col.data_type, DataType::Unknown);
        assert_eq!(col.to_string(), "UNKNOWN");
    }
}