1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
9pub struct SqlTypeName(String);
10
11impl SqlTypeName {
12 pub fn new(input: impl AsRef<str>) -> Result<Self, SqlTypeError> {
18 validate_type_label(input.as_ref()).map(|value| Self(value.to_owned()))
19 }
20
21 #[must_use]
23 pub fn from_scalar(scalar: SqlScalarType) -> Self {
24 Self(scalar.as_str().to_owned())
25 }
26
27 #[must_use]
29 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32}
33
34impl AsRef<str> for SqlTypeName {
35 fn as_ref(&self) -> &str {
36 self.as_str()
37 }
38}
39
40impl fmt::Display for SqlTypeName {
41 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
42 formatter.write_str(self.as_str())
43 }
44}
45
46impl FromStr for SqlTypeName {
47 type Err = SqlTypeError;
48
49 fn from_str(input: &str) -> Result<Self, Self::Err> {
50 Self::new(input)
51 }
52}
53
54impl TryFrom<&str> for SqlTypeName {
55 type Error = SqlTypeError;
56
57 fn try_from(value: &str) -> Result<Self, Self::Error> {
58 Self::new(value)
59 }
60}
61
/// Scalar SQL type categories, each with one canonical lowercase label.
///
/// [`SqlScalarType::Text`] is the default. The derived ordering follows the
/// declaration order of the variants.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlScalarType {
    /// Canonical label `text`.
    #[default]
    Text,
    /// Canonical label `integer`.
    Integer,
    /// Canonical label `bigint`.
    BigInt,
    /// Canonical label `boolean`.
    Boolean,
    /// Canonical label `decimal`.
    Decimal,
    /// Canonical label `float`.
    Float,
    /// Canonical label `date`.
    Date,
    /// Canonical label `time`.
    Time,
    /// Canonical label `timestamp`.
    Timestamp,
    /// Canonical label `json`.
    Json,
    /// Canonical label `uuid`.
    Uuid,
    /// Canonical label `binary`.
    Binary,
}

impl SqlScalarType {
    /// Returns the canonical lowercase label of this scalar type.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Integer => "integer",
            Self::BigInt => "bigint",
            Self::Boolean => "boolean",
            Self::Decimal => "decimal",
            Self::Float => "float",
            Self::Date => "date",
            Self::Time => "time",
            Self::Timestamp => "timestamp",
            Self::Json => "json",
            Self::Uuid => "uuid",
            Self::Binary => "binary",
        }
    }
}

impl fmt::Display for SqlScalarType {
    /// Writes the canonical label, honoring the caller's width, fill,
    /// alignment and precision flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the format spec (e.g. `{:>10}`) the same way `str`'s
        // `Display` does; `write_str` would drop it on the floor. With no
        // flags the output is exactly `as_str()`.
        formatter.pad(self.as_str())
    }
}
106
107impl FromStr for SqlScalarType {
108 type Err = SqlTypeError;
109
110 fn from_str(input: &str) -> Result<Self, Self::Err> {
111 match normalized_type_label(input)?.as_str() {
112 "text" | "string" | "varchar" | "character varying" | "char" => Ok(Self::Text),
113 "int" | "integer" => Ok(Self::Integer),
114 "bigint" | "big int" => Ok(Self::BigInt),
115 "bool" | "boolean" => Ok(Self::Boolean),
116 "decimal" | "numeric" => Ok(Self::Decimal),
117 "float" | "real" | "double" | "double precision" => Ok(Self::Float),
118 "date" => Ok(Self::Date),
119 "time" => Ok(Self::Time),
120 "timestamp" | "datetime" => Ok(Self::Timestamp),
121 "json" | "jsonb" => Ok(Self::Json),
122 "uuid" => Ok(Self::Uuid),
123 "binary" | "blob" | "bytea" => Ok(Self::Binary),
124 _ => Err(SqlTypeError::UnknownScalar),
125 }
126 }
127}
128
129impl TryFrom<&str> for SqlScalarType {
130 type Error = SqlTypeError;
131
132 fn try_from(value: &str) -> Result<Self, Self::Error> {
133 value.parse()
134 }
135}
136
/// A modifier attached to a SQL type, rendered as SQL text by `Display`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlTypeModifier {
    /// Rendered as `ARRAY`.
    Array,
    /// Rendered as `NULL`.
    Nullable,
    /// Rendered as `NOT NULL`.
    NotNull,
    /// Rendered as `(precision)` or, when a scale is present,
    /// `(precision, scale)`.
    Precision { precision: u16, scale: Option<u16> },
}

impl fmt::Display for SqlTypeModifier {
    /// Writes the SQL text of the modifier.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Array => formatter.write_str("ARRAY"),
            Self::Nullable => formatter.write_str("NULL"),
            Self::NotNull => formatter.write_str("NOT NULL"),
            Self::Precision {
                precision,
                scale: Some(scale),
            } => write!(formatter, "({precision}, {scale})"),
            Self::Precision {
                precision,
                scale: None,
            } => write!(formatter, "({precision})"),
        }
    }
}
162
/// Errors produced while validating or parsing SQL type labels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqlTypeError {
    /// The label was empty once surrounding whitespace was removed.
    Empty,
    /// The label contained a control character.
    ControlCharacter,
    /// The label did not match any known scalar type.
    UnknownScalar,
}

impl fmt::Display for SqlTypeError {
    /// Writes a fixed human-readable message for the error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "SQL type label cannot be empty",
            Self::ControlCharacter => "SQL type label cannot contain control characters",
            Self::UnknownScalar => "unknown SQL scalar type label",
        };
        formatter.write_str(message)
    }
}

impl Error for SqlTypeError {}
184
185fn validate_type_label(input: &str) -> Result<&str, SqlTypeError> {
186 let trimmed = input.trim();
187 if trimmed.is_empty() {
188 return Err(SqlTypeError::Empty);
189 }
190 if trimmed.chars().any(char::is_control) {
191 return Err(SqlTypeError::ControlCharacter);
192 }
193 Ok(trimmed)
194}
195
196fn normalized_type_label(input: &str) -> Result<String, SqlTypeError> {
197 let trimmed = validate_type_label(input)?;
198 Ok(trimmed
199 .replace('_', " ")
200 .split_whitespace()
201 .collect::<Vec<_>>()
202 .join(" ")
203 .to_ascii_lowercase())
204}
205
#[cfg(test)]
mod tests {
    use super::{SqlScalarType, SqlTypeError, SqlTypeModifier, SqlTypeName};

    /// Aliases resolve to their scalar category.
    #[test]
    fn parses_common_scalar_types() -> Result<(), SqlTypeError> {
        let cases = [
            ("varchar", SqlScalarType::Text),
            ("numeric", SqlScalarType::Decimal),
            ("blob", SqlScalarType::Binary),
        ];
        for (label, expected) in cases {
            assert_eq!(label.parse::<SqlScalarType>()?, expected);
        }
        Ok(())
    }

    /// Names are trimmed, keep their casing, and reject empty input.
    #[test]
    fn validates_type_names() -> Result<(), SqlTypeError> {
        let trimmed = SqlTypeName::new(" NUMERIC ")?;
        assert_eq!(trimmed.as_str(), "NUMERIC");

        let from_scalar = SqlTypeName::from_scalar(SqlScalarType::Uuid);
        assert_eq!(from_scalar.to_string(), "uuid");

        let empty = SqlTypeName::new("");
        assert_eq!(empty, Err(SqlTypeError::Empty));
        Ok(())
    }

    /// Modifiers render as SQL text.
    #[test]
    fn renders_modifiers() {
        let not_null = SqlTypeModifier::NotNull;
        assert_eq!(not_null.to_string(), "NOT NULL");

        let precision = SqlTypeModifier::Precision {
            precision: 10,
            scale: Some(2),
        };
        assert_eq!(precision.to_string(), "(10, 2)");
    }
}