Skip to main content

use_sql_value/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// Unit marker type standing for the SQL `NULL` literal.
///
/// The type carries no data; its [`fmt::Display`] output is always the
/// keyword `NULL`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlNull;

impl fmt::Display for SqlNull {
    /// Writes the keyword `NULL`, ignoring any width or alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KEYWORD: &str = "NULL";
        f.write_str(KEYWORD)
    }
}
16
/// Owned text of a SQL string literal, stored without quoting or escaping.
///
/// Quoting is applied only when the value is rendered through
/// [`fmt::Display`].
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlStringLiteral(String);

impl SqlStringLiteral {
    /// Wraps `input` as literal text; the text is stored exactly as given.
    #[must_use]
    pub fn new(input: impl Into<String>) -> Self {
        let text: String = input.into();
        Self(text)
    }

    /// Borrows the raw text, without surrounding quotes or doubled quotes.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl AsRef<str> for SqlStringLiteral {
    /// Borrows the same raw text as [`SqlStringLiteral::as_str`].
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for SqlStringLiteral {
    /// Writes the text wrapped in single quotes, doubling every embedded `'`.
    ///
    /// No other character is altered; in particular, backslashes are written
    /// unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("'")?;
        // Splitting on the quote yields the runs between quotes; emitting `''`
        // between consecutive runs doubles each quote exactly once.
        let mut runs = self.0.split('\'');
        if let Some(first) = runs.next() {
            f.write_str(first)?;
        }
        for run in runs {
            f.write_str("''")?;
            f.write_str(run)?;
        }
        f.write_str("'")
    }
}
53
54/// SQL number literal text.
55#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
56pub struct SqlNumberLiteral(String);
57
58impl SqlNumberLiteral {
59    /// Creates a conservatively validated number literal.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`SqlValueError`] when the number is empty or not a finite numeric literal.
64    pub fn new(input: impl AsRef<str>) -> Result<Self, SqlValueError> {
65        let trimmed = input.as_ref().trim();
66        if trimmed.is_empty() {
67            return Err(SqlValueError::EmptyNumber);
68        }
69        if !trimmed.chars().any(|character| character.is_ascii_digit()) {
70            return Err(SqlValueError::InvalidNumber);
71        }
72        let value = trimmed
73            .parse::<f64>()
74            .map_err(|_| SqlValueError::InvalidNumber)?;
75        if !value.is_finite() {
76            return Err(SqlValueError::InvalidNumber);
77        }
78        Ok(Self(trimmed.to_owned()))
79    }
80
81    /// Returns the stored number literal text.
82    #[must_use]
83    pub fn as_str(&self) -> &str {
84        &self.0
85    }
86}
87
88impl AsRef<str> for SqlNumberLiteral {
89    fn as_ref(&self) -> &str {
90        self.as_str()
91    }
92}
93
94impl fmt::Display for SqlNumberLiteral {
95    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
96        formatter.write_str(self.as_str())
97    }
98}
99
100impl FromStr for SqlNumberLiteral {
101    type Err = SqlValueError;
102
103    fn from_str(input: &str) -> Result<Self, Self::Err> {
104        Self::new(input)
105    }
106}
107
108impl TryFrom<&str> for SqlNumberLiteral {
109    type Error = SqlValueError;
110
111    fn try_from(value: &str) -> Result<Self, Self::Error> {
112        Self::new(value)
113    }
114}
115
/// SQL boolean literal wrapping a Rust `bool`.
///
/// The derived `Default` is the `false` literal.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlBooleanLiteral(bool);

impl SqlBooleanLiteral {
    /// Wraps `value` as a boolean literal.
    #[must_use]
    pub const fn new(value: bool) -> Self {
        Self(value)
    }

    /// Unwraps the literal back into its `bool`.
    #[must_use]
    pub const fn value(self) -> bool {
        let Self(inner) = self;
        inner
    }
}

impl fmt::Display for SqlBooleanLiteral {
    /// Writes the upper-case keyword `TRUE` or `FALSE`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.0 {
            f.write_str("TRUE")
        } else {
            f.write_str("FALSE")
        }
    }
}
139
140/// Simple SQL literal/value primitives.
141#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
142pub enum SqlValue {
143    Null(SqlNull),
144    String(SqlStringLiteral),
145    Number(SqlNumberLiteral),
146    Boolean(SqlBooleanLiteral),
147}
148
149impl SqlValue {
150    /// Returns a null value.
151    #[must_use]
152    pub const fn null() -> Self {
153        Self::Null(SqlNull)
154    }
155
156    /// Returns a string literal value.
157    #[must_use]
158    pub fn string(input: impl Into<String>) -> Self {
159        Self::String(SqlStringLiteral::new(input))
160    }
161
162    /// Returns a number literal value.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`SqlValueError`] when number validation fails.
167    pub fn number(input: impl AsRef<str>) -> Result<Self, SqlValueError> {
168        SqlNumberLiteral::new(input).map(Self::Number)
169    }
170
171    /// Returns a boolean literal value.
172    #[must_use]
173    pub const fn boolean(value: bool) -> Self {
174        Self::Boolean(SqlBooleanLiteral::new(value))
175    }
176}
177
178impl Default for SqlValue {
179    fn default() -> Self {
180        Self::null()
181    }
182}
183
184impl fmt::Display for SqlValue {
185    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
186        match self {
187            Self::Null(value) => value.fmt(formatter),
188            Self::String(value) => value.fmt(formatter),
189            Self::Number(value) => value.fmt(formatter),
190            Self::Boolean(value) => value.fmt(formatter),
191        }
192    }
193}
194
/// Reasons a SQL literal can be rejected during validation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SqlValueError {
    /// The number text was empty.
    EmptyNumber,
    /// The number text was not a valid number literal.
    InvalidNumber,
}

impl fmt::Display for SqlValueError {
    /// Writes a short, fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EmptyNumber => "SQL number literal cannot be empty",
            Self::InvalidNumber => "invalid SQL number literal",
        };
        f.write_str(message)
    }
}

// Marker impl only: the default `Error` methods are used as-is.
impl Error for SqlValueError {}
212
#[cfg(test)]
mod tests {
    use super::{SqlBooleanLiteral, SqlNumberLiteral, SqlValue, SqlValueError};

    /// Each literal kind renders to the expected SQL text.
    #[test]
    fn renders_simple_literals() -> Result<(), SqlValueError> {
        let null = SqlValue::null();
        assert_eq!(null.to_string(), "NULL");

        let quoted = SqlValue::string("Ada's");
        assert_eq!(quoted.to_string(), "'Ada''s'");

        let number = SqlValue::number("42.5")?;
        assert_eq!(number.to_string(), "42.5");

        let truthy = SqlBooleanLiteral::new(true);
        assert_eq!(truthy.to_string(), "TRUE");

        Ok(())
    }

    /// Empty, not-a-number, and overflowing inputs are rejected with the
    /// matching error variant.
    #[test]
    fn validates_number_literals() {
        let rejected = [
            ("", SqlValueError::EmptyNumber),
            ("NaN", SqlValueError::InvalidNumber),
            ("1e999", SqlValueError::InvalidNumber),
        ];
        for (input, expected) in rejected {
            assert_eq!(SqlNumberLiteral::new(input), Err(expected));
        }
    }
}