1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
9pub struct SqlIdentifier(String);
10
11impl SqlIdentifier {
12 pub fn new(input: impl AsRef<str>) -> Result<Self, SqlIdentifierError> {
19 validate_identifier_text(input.as_ref()).map(|value| Self(value.to_owned()))
20 }
21
22 #[must_use]
24 pub fn as_str(&self) -> &str {
25 &self.0
26 }
27
28 #[must_use]
30 pub fn into_string(self) -> String {
31 self.0
32 }
33}
34
35impl AsRef<str> for SqlIdentifier {
36 fn as_ref(&self) -> &str {
37 self.as_str()
38 }
39}
40
41impl fmt::Display for SqlIdentifier {
42 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
43 formatter.write_str(self.as_str())
44 }
45}
46
47impl FromStr for SqlIdentifier {
48 type Err = SqlIdentifierError;
49
50 fn from_str(input: &str) -> Result<Self, Self::Err> {
51 Self::new(input)
52 }
53}
54
55impl TryFrom<&str> for SqlIdentifier {
56 type Error = SqlIdentifierError;
57
58 fn try_from(value: &str) -> Result<Self, Self::Error> {
59 Self::new(value)
60 }
61}
62
63#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
65pub struct SqlQualifiedName {
66 parts: Vec<SqlIdentifier>,
67}
68
69impl SqlQualifiedName {
70 pub fn new(parts: Vec<SqlIdentifier>) -> Result<Self, SqlIdentifierError> {
76 if parts.is_empty() {
77 return Err(SqlIdentifierError::EmptyQualifiedName);
78 }
79
80 Ok(Self { parts })
81 }
82
83 pub fn parse(input: &str) -> Result<Self, SqlIdentifierError> {
89 let trimmed = input.trim();
90 if trimmed.is_empty() {
91 return Err(SqlIdentifierError::EmptyQualifiedName);
92 }
93
94 let parts = trimmed
95 .split('.')
96 .map(SqlIdentifier::new)
97 .collect::<Result<Vec<_>, _>>()?;
98 Self::new(parts)
99 }
100
101 #[must_use]
103 pub fn parts(&self) -> &[SqlIdentifier] {
104 &self.parts
105 }
106}
107
108impl fmt::Display for SqlQualifiedName {
109 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
110 let mut parts = self.parts.iter();
111 if let Some(first) = parts.next() {
112 write!(formatter, "{first}")?;
113 }
114 for part in parts {
115 write!(formatter, ".{part}")?;
116 }
117 Ok(())
118 }
119}
120
121impl FromStr for SqlQualifiedName {
122 type Err = SqlIdentifierError;
123
124 fn from_str(input: &str) -> Result<Self, Self::Err> {
125 Self::parse(input)
126 }
127}
128
129impl TryFrom<&str> for SqlQualifiedName {
130 type Error = SqlIdentifierError;
131
132 fn try_from(value: &str) -> Result<Self, Self::Error> {
133 Self::parse(value)
134 }
135}
136
137#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
139pub struct SqlAlias(SqlIdentifier);
140
141impl SqlAlias {
142 pub fn new(input: impl AsRef<str>) -> Result<Self, SqlIdentifierError> {
148 SqlIdentifier::new(input).map(Self)
149 }
150
151 #[must_use]
153 pub const fn identifier(&self) -> &SqlIdentifier {
154 &self.0
155 }
156
157 #[must_use]
159 pub fn as_str(&self) -> &str {
160 self.0.as_str()
161 }
162}
163
164impl AsRef<str> for SqlAlias {
165 fn as_ref(&self) -> &str {
166 self.as_str()
167 }
168}
169
170impl fmt::Display for SqlAlias {
171 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
172 self.0.fmt(formatter)
173 }
174}
175
176impl FromStr for SqlAlias {
177 type Err = SqlIdentifierError;
178
179 fn from_str(input: &str) -> Result<Self, Self::Err> {
180 Self::new(input)
181 }
182}
183
184impl TryFrom<&str> for SqlAlias {
185 type Error = SqlIdentifierError;
186
187 fn try_from(value: &str) -> Result<Self, Self::Error> {
188 Self::new(value)
189 }
190}
191
/// Reasons an identifier or qualified name was rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlIdentifierError {
    /// The identifier text was empty (or only whitespace).
    Empty,
    /// A single identifier segment contained a `.`.
    ContainsDot,
    /// A qualified name had no segments, or the text to parse was blank.
    EmptyQualifiedName,
    /// The identifier contained a control character.
    ControlCharacter {
        /// Byte index at which the offending character was found.
        index: usize,
        /// The offending character itself.
        character: char,
    },
}

impl fmt::Display for SqlIdentifierError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "SQL identifier cannot be empty",
            Self::ContainsDot => "SQL identifier segment cannot contain a dot",
            Self::EmptyQualifiedName => "SQL qualified name cannot be empty",
            Self::ControlCharacter { index, character } => {
                // The only variant with a payload needs real formatting.
                return write!(
                    f,
                    "SQL identifier contains control character {character:?} at byte index {index}"
                );
            }
        };
        f.write_str(message)
    }
}

impl Error for SqlIdentifierError {}
225
226#[must_use]
228pub fn is_valid_unquoted_ident(input: &str) -> bool {
229 validate_unquoted_ident(input).is_ok()
230}
231
232#[must_use]
234pub fn needs_quoting(input: &str) -> bool {
235 !is_valid_unquoted_ident(input)
236}
237
/// Wraps `input` in double quotes, doubling every embedded `"`.
///
/// Leading and trailing whitespace is trimmed first. No other validation is
/// performed: an empty input yields `""` (two quote characters).
#[must_use]
pub fn quote_ident(input: &str) -> String {
    let escaped = input.trim().replace('"', "\"\"");
    format!("\"{escaped}\"")
}
253
254#[must_use]
256pub fn normalize_ident(input: &str) -> String {
257 let trimmed = input.trim();
258 if is_valid_unquoted_ident(trimmed) {
259 trimmed.to_ascii_lowercase()
260 } else {
261 quote_ident(trimmed)
262 }
263}
264
265fn validate_identifier_text(input: &str) -> Result<&str, SqlIdentifierError> {
266 let trimmed = input.trim();
267 if trimmed.is_empty() {
268 return Err(SqlIdentifierError::Empty);
269 }
270 if trimmed.contains('.') {
271 return Err(SqlIdentifierError::ContainsDot);
272 }
273 if let Some((index, character)) = trimmed
274 .char_indices()
275 .find(|(_, character)| character.is_control())
276 {
277 return Err(SqlIdentifierError::ControlCharacter { index, character });
278 }
279 Ok(trimmed)
280}
281
282fn validate_unquoted_ident(input: &str) -> Result<(), SqlIdentifierError> {
283 let trimmed = validate_identifier_text(input)?;
284 let mut characters = trimmed.chars();
285 let Some(first) = characters.next() else {
286 return Err(SqlIdentifierError::Empty);
287 };
288 if !(first == '_' || first.is_ascii_alphabetic()) {
289 return Err(SqlIdentifierError::Empty);
290 }
291 if !characters.all(|character| character == '_' || character.is_ascii_alphanumeric()) {
292 return Err(SqlIdentifierError::Empty);
293 }
294 if is_reserved_like(trimmed) {
295 return Err(SqlIdentifierError::Empty);
296 }
297 Ok(())
298}
299
/// Returns `true` when `input` (trimmed) matches a keyword that should not be
/// used as a bare identifier. The comparison is ASCII case-insensitive.
///
/// This is a dialect-agnostic heuristic, not a complete keyword table. It
/// covers common statement/DDL keywords plus keywords that are reserved in
/// PostgreSQL and the SQL standard. Without the latter, words such as `as`,
/// `and`, `on` or `user` would be emitted unquoted and produce invalid SQL.
/// Over-quoting a lowercase name is harmless, so the list errs on the side of
/// inclusion.
fn is_reserved_like(input: &str) -> bool {
    // Stored uppercase. Matching uses `eq_ignore_ascii_case`, which avoids
    // allocating an uppercased copy of `input` on every call.
    const RESERVED_LIKE: &[&str] = &[
        // Statement and DDL keywords.
        "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "TABLE", "VIEW",
        "INDEX", "WHERE", "FROM", "JOIN", "GROUP", "ORDER", "LIMIT", "OFFSET", "RETURNING",
        "PRIMARY", "FOREIGN", "KEY", "UNIQUE", "NOT", "NULL", "CHECK", "DEFAULT",
        // Reserved in PostgreSQL (and the SQL standard).
        "ALL", "AND", "ANY", "ARRAY", "AS", "ASC", "BOTH", "CASE", "CAST", "COLLATE", "COLUMN",
        "CONSTRAINT", "CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP",
        "CURRENT_USER", "DESC", "DISTINCT", "DO", "ELSE", "END", "EXCEPT", "FALSE", "FETCH",
        "FOR", "FULL", "GRANT", "HAVING", "IN", "INNER", "INTERSECT", "INTO", "IS", "LEADING",
        "LEFT", "LIKE", "NATURAL", "ON", "ONLY", "OR", "OUTER", "REFERENCES", "RIGHT",
        "SESSION_USER", "SOME", "THEN", "TO", "TRAILING", "TRUE", "UNION", "USER", "USING",
        "WHEN", "WINDOW", "WITH",
    ];

    let candidate = input.trim();
    RESERVED_LIKE
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(candidate))
}
331
#[cfg(test)]
mod tests {
    use super::{
        SqlIdentifier, SqlIdentifierError, SqlQualifiedName, is_valid_unquoted_ident,
        needs_quoting, normalize_ident, quote_ident,
    };

    #[test]
    fn validates_identifier_text() -> Result<(), SqlIdentifierError> {
        // Surrounding whitespace is trimmed away.
        let trimmed = SqlIdentifier::new(" users ")?;
        assert_eq!(trimmed.as_str(), "users");

        // Blank text and dotted text are rejected with distinct errors.
        let rejected = [
            ("", SqlIdentifierError::Empty),
            ("public.users", SqlIdentifierError::ContainsDot),
        ];
        for (input, expected) in rejected {
            assert_eq!(SqlIdentifier::new(input), Err(expected));
        }
        Ok(())
    }

    #[test]
    fn checks_unquoted_identifiers() {
        // Plain ASCII words are accepted; a leading digit or a keyword is not.
        assert!(is_valid_unquoted_ident("users_1"));
        for rejected in ["1users", "select"] {
            assert!(!is_valid_unquoted_ident(rejected), "{rejected} should be rejected");
        }
        // Embedded whitespace forces quoting.
        assert!(needs_quoting("order items"));
    }

    #[test]
    fn quotes_and_normalizes_identifiers() {
        // Embedded double quotes are doubled inside the delimiters.
        assert_eq!(quote_ident("user\"name"), "\"user\"\"name\"");
        // Bare-safe identifiers fold to lower case; keywords are quoted instead.
        assert_eq!(normalize_ident("Users"), "users");
        assert_eq!(normalize_ident("select"), "\"select\"");
    }

    #[test]
    fn parses_qualified_names() -> Result<(), SqlIdentifierError> {
        let qualified = SqlQualifiedName::parse("public.users")?;
        assert_eq!(qualified.parts().len(), 2);
        assert_eq!(qualified.to_string(), "public.users");

        // A trailing dot leaves an empty final segment, which is an error.
        assert!(SqlQualifiedName::parse("public.").is_err());
        Ok(())
    }
}