Skip to main content

use_sql_type/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7/// A SQL type name label.
8#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
9pub struct SqlTypeName(String);
10
11impl SqlTypeName {
12    /// Creates a SQL type name label.
13    ///
14    /// # Errors
15    ///
16    /// Returns [`SqlTypeError`] when the label is empty or contains control characters.
17    pub fn new(input: impl AsRef<str>) -> Result<Self, SqlTypeError> {
18        validate_type_label(input.as_ref()).map(|value| Self(value.to_owned()))
19    }
20
21    /// Creates a canonical type name from a scalar type.
22    #[must_use]
23    pub fn from_scalar(scalar: SqlScalarType) -> Self {
24        Self(scalar.as_str().to_owned())
25    }
26
27    /// Returns the stored type name.
28    #[must_use]
29    pub fn as_str(&self) -> &str {
30        &self.0
31    }
32}
33
34impl AsRef<str> for SqlTypeName {
35    fn as_ref(&self) -> &str {
36        self.as_str()
37    }
38}
39
40impl fmt::Display for SqlTypeName {
41    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
42        formatter.write_str(self.as_str())
43    }
44}
45
46impl FromStr for SqlTypeName {
47    type Err = SqlTypeError;
48
49    fn from_str(input: &str) -> Result<Self, Self::Err> {
50        Self::new(input)
51    }
52}
53
54impl TryFrom<&str> for SqlTypeName {
55    type Error = SqlTypeError;
56
57    fn try_from(value: &str) -> Result<Self, Self::Error> {
58        Self::new(value)
59    }
60}
61
62/// Common scalar SQL type labels.
63#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
64pub enum SqlScalarType {
65    #[default]
66    Text,
67    Integer,
68    BigInt,
69    Boolean,
70    Decimal,
71    Float,
72    Date,
73    Time,
74    Timestamp,
75    Json,
76    Uuid,
77    Binary,
78}
79
80impl SqlScalarType {
81    /// Returns the stable lowercase scalar type label.
82    #[must_use]
83    pub const fn as_str(self) -> &'static str {
84        match self {
85            Self::Text => "text",
86            Self::Integer => "integer",
87            Self::BigInt => "bigint",
88            Self::Boolean => "boolean",
89            Self::Decimal => "decimal",
90            Self::Float => "float",
91            Self::Date => "date",
92            Self::Time => "time",
93            Self::Timestamp => "timestamp",
94            Self::Json => "json",
95            Self::Uuid => "uuid",
96            Self::Binary => "binary",
97        }
98    }
99}
100
101impl fmt::Display for SqlScalarType {
102    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
103        formatter.write_str(self.as_str())
104    }
105}
106
107impl FromStr for SqlScalarType {
108    type Err = SqlTypeError;
109
110    fn from_str(input: &str) -> Result<Self, Self::Err> {
111        match normalized_type_label(input)?.as_str() {
112            "text" | "string" | "varchar" | "character varying" | "char" => Ok(Self::Text),
113            "int" | "integer" => Ok(Self::Integer),
114            "bigint" | "big int" => Ok(Self::BigInt),
115            "bool" | "boolean" => Ok(Self::Boolean),
116            "decimal" | "numeric" => Ok(Self::Decimal),
117            "float" | "real" | "double" | "double precision" => Ok(Self::Float),
118            "date" => Ok(Self::Date),
119            "time" => Ok(Self::Time),
120            "timestamp" | "datetime" => Ok(Self::Timestamp),
121            "json" | "jsonb" => Ok(Self::Json),
122            "uuid" => Ok(Self::Uuid),
123            "binary" | "blob" | "bytea" => Ok(Self::Binary),
124            _ => Err(SqlTypeError::UnknownScalar),
125        }
126    }
127}
128
129impl TryFrom<&str> for SqlScalarType {
130    type Error = SqlTypeError;
131
132    fn try_from(value: &str) -> Result<Self, Self::Error> {
133        value.parse()
134    }
135}
136
137/// Lightweight SQL type modifier labels.
138#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
139pub enum SqlTypeModifier {
140    Array,
141    Nullable,
142    NotNull,
143    Precision { precision: u16, scale: Option<u16> },
144}
145
146impl fmt::Display for SqlTypeModifier {
147    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
148        match self {
149            Self::Array => formatter.write_str("ARRAY"),
150            Self::Nullable => formatter.write_str("NULL"),
151            Self::NotNull => formatter.write_str("NOT NULL"),
152            Self::Precision { precision, scale } => {
153                if let Some(scale) = scale {
154                    write!(formatter, "({precision}, {scale})")
155                } else {
156                    write!(formatter, "({precision})")
157                }
158            },
159        }
160    }
161}
162
163/// Error returned when SQL type labels are invalid.
164#[derive(Clone, Copy, Debug, Eq, PartialEq)]
165pub enum SqlTypeError {
166    Empty,
167    ControlCharacter,
168    UnknownScalar,
169}
170
171impl fmt::Display for SqlTypeError {
172    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
173        match self {
174            Self::Empty => formatter.write_str("SQL type label cannot be empty"),
175            Self::ControlCharacter => {
176                formatter.write_str("SQL type label cannot contain control characters")
177            },
178            Self::UnknownScalar => formatter.write_str("unknown SQL scalar type label"),
179        }
180    }
181}
182
183impl Error for SqlTypeError {}
184
185fn validate_type_label(input: &str) -> Result<&str, SqlTypeError> {
186    let trimmed = input.trim();
187    if trimmed.is_empty() {
188        return Err(SqlTypeError::Empty);
189    }
190    if trimmed.chars().any(char::is_control) {
191        return Err(SqlTypeError::ControlCharacter);
192    }
193    Ok(trimmed)
194}
195
196fn normalized_type_label(input: &str) -> Result<String, SqlTypeError> {
197    let trimmed = validate_type_label(input)?;
198    Ok(trimmed
199        .replace('_', " ")
200        .split_whitespace()
201        .collect::<Vec<_>>()
202        .join(" ")
203        .to_ascii_lowercase())
204}
205
206#[cfg(test)]
207mod tests {
208    use super::{SqlScalarType, SqlTypeError, SqlTypeModifier, SqlTypeName};
209
210    #[test]
211    fn parses_common_scalar_types() -> Result<(), SqlTypeError> {
212        assert_eq!("varchar".parse::<SqlScalarType>()?, SqlScalarType::Text);
213        assert_eq!("numeric".parse::<SqlScalarType>()?, SqlScalarType::Decimal);
214        assert_eq!("blob".parse::<SqlScalarType>()?, SqlScalarType::Binary);
215        Ok(())
216    }
217
218    #[test]
219    fn validates_type_names() -> Result<(), SqlTypeError> {
220        let name = SqlTypeName::new(" NUMERIC ")?;
221        assert_eq!(name.as_str(), "NUMERIC");
222        assert_eq!(
223            SqlTypeName::from_scalar(SqlScalarType::Uuid).to_string(),
224            "uuid"
225        );
226        assert_eq!(SqlTypeName::new(""), Err(SqlTypeError::Empty));
227        Ok(())
228    }
229
230    #[test]
231    fn renders_modifiers() {
232        assert_eq!(SqlTypeModifier::NotNull.to_string(), "NOT NULL");
233        assert_eq!(
234            SqlTypeModifier::Precision {
235                precision: 10,
236                scale: Some(2)
237            }
238            .to_string(),
239            "(10, 2)"
240        );
241    }
242}