Skip to main content

use_sql_clause/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// Common SQL clause labels.
///
/// Variants are declared in the same order as their [`SqlClauseKind::ordinal`]
/// values, so the derived `Ord` agrees with [`SqlClauseKind::comes_before`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlClauseKind {
    /// `SELECT`; also the [`Default`] kind.
    #[default]
    Select,
    /// `FROM`.
    From,
    /// `WHERE`.
    Where,
    /// `GROUP BY`.
    GroupBy,
    /// `HAVING`.
    Having,
    /// `ORDER BY`.
    OrderBy,
    /// `LIMIT`.
    Limit,
    /// `OFFSET`.
    Offset,
    /// `RETURNING`.
    Returning,
}

impl SqlClauseKind {
    /// Returns the stable clause label: upper-case, with a single space in the
    /// two-word labels (`"GROUP BY"`, `"ORDER BY"`).
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Select => "SELECT",
            Self::From => "FROM",
            Self::Where => "WHERE",
            Self::GroupBy => "GROUP BY",
            Self::Having => "HAVING",
            Self::OrderBy => "ORDER BY",
            Self::Limit => "LIMIT",
            Self::Offset => "OFFSET",
            Self::Returning => "RETURNING",
        }
    }

    /// Returns the common select-statement clause ordinal.
    ///
    /// Values rise in steps of ten from `SELECT` (10) to `RETURNING` (90).
    /// [`SqlClauseKind::comes_before`] and [`SqlClauseKind::comes_after`] only
    /// compare these values, so only their relative order matters there.
    #[must_use]
    pub const fn ordinal(self) -> u8 {
        match self {
            Self::Select => 10,
            Self::From => 20,
            Self::Where => 30,
            Self::GroupBy => 40,
            Self::Having => 50,
            Self::OrderBy => 60,
            Self::Limit => 70,
            Self::Offset => 80,
            Self::Returning => 90,
        }
    }

    /// Returns whether `self` commonly appears strictly before `other`.
    ///
    /// A kind never comes before itself.
    #[must_use]
    pub const fn comes_before(self, other: Self) -> bool {
        self.ordinal() < other.ordinal()
    }

    /// Returns whether `self` commonly appears strictly after `other`.
    ///
    /// A kind never comes after itself.
    #[must_use]
    pub const fn comes_after(self, other: Self) -> bool {
        self.ordinal() > other.ordinal()
    }
}

impl fmt::Display for SqlClauseKind {
    /// Writes the label returned by [`SqlClauseKind::as_str`].
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `{:>8}`) are honoured, as for `str`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.pad(self.as_str())
    }
}
73
74impl FromStr for SqlClauseKind {
75    type Err = SqlClauseParseError;
76
77    fn from_str(input: &str) -> Result<Self, Self::Err> {
78        match normalized_clause(input)?.as_str() {
79            "SELECT" => Ok(Self::Select),
80            "FROM" => Ok(Self::From),
81            "WHERE" => Ok(Self::Where),
82            "GROUP BY" | "GROUP" => Ok(Self::GroupBy),
83            "HAVING" => Ok(Self::Having),
84            "ORDER BY" | "ORDER" => Ok(Self::OrderBy),
85            "LIMIT" => Ok(Self::Limit),
86            "OFFSET" => Ok(Self::Offset),
87            "RETURNING" => Ok(Self::Returning),
88            _ => Err(SqlClauseParseError::Unknown),
89        }
90    }
91}
92
93/// A SQL clause label with optional raw text metadata.
94#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
95pub struct SqlClause {
96    kind: SqlClauseKind,
97    text: Option<String>,
98}
99
100impl SqlClause {
101    /// Creates a clause label.
102    #[must_use]
103    pub const fn new(kind: SqlClauseKind) -> Self {
104        Self { kind, text: None }
105    }
106
107    /// Attaches raw clause text metadata.
108    #[must_use]
109    pub fn with_text(mut self, text: impl Into<String>) -> Self {
110        self.text = Some(text.into());
111        self
112    }
113
114    /// Returns the clause kind.
115    #[must_use]
116    pub const fn kind(&self) -> SqlClauseKind {
117        self.kind
118    }
119
120    /// Returns optional raw clause text metadata.
121    #[must_use]
122    pub fn text(&self) -> Option<&str> {
123        self.text.as_deref()
124    }
125}
126
127impl fmt::Display for SqlClause {
128    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
129        formatter.write_str(self.kind.as_str())
130    }
131}
132
133/// Helper type for common clause ordering.
134#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
135pub struct SqlClauseOrder;
136
137impl SqlClauseOrder {
138    /// Sorts clause kinds by their common select-statement order.
139    #[must_use]
140    pub fn sort_kinds(mut kinds: Vec<SqlClauseKind>) -> Vec<SqlClauseKind> {
141        kinds.sort_by_key(|kind| kind.ordinal());
142        kinds
143    }
144}
145
/// Error returned when parsing clause labels fails.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SqlClauseParseError {
    /// The input contained no label text (for example it was empty or
    /// whitespace-only).
    Empty,
    /// The input was not blank but did not name a recognised clause label.
    Unknown,
}

impl fmt::Display for SqlClauseParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "SQL clause label cannot be empty",
            Self::Unknown => "unknown SQL clause label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause is wrapped, so the default `source()` (returning `None`) applies.
impl Error for SqlClauseParseError {}
163
164fn normalized_clause(input: &str) -> Result<String, SqlClauseParseError> {
165    let trimmed = input.trim();
166    if trimmed.is_empty() {
167        return Err(SqlClauseParseError::Empty);
168    }
169    Ok(trimmed
170        .replace('_', " ")
171        .split_whitespace()
172        .collect::<Vec<_>>()
173        .join(" ")
174        .to_ascii_uppercase())
175}
176
#[cfg(test)]
mod tests {
    use super::{SqlClause, SqlClauseKind, SqlClauseOrder, SqlClauseParseError};

    /// Every clause kind, in declaration order.
    const ALL_KINDS: [SqlClauseKind; 9] = [
        SqlClauseKind::Select,
        SqlClauseKind::From,
        SqlClauseKind::Where,
        SqlClauseKind::GroupBy,
        SqlClauseKind::Having,
        SqlClauseKind::OrderBy,
        SqlClauseKind::Limit,
        SqlClauseKind::Offset,
        SqlClauseKind::Returning,
    ];

    #[test]
    fn parses_clause_labels() -> Result<(), SqlClauseParseError> {
        assert_eq!("group by".parse::<SqlClauseKind>()?, SqlClauseKind::GroupBy);
        assert_eq!("  Order_By ".parse::<SqlClauseKind>()?, SqlClauseKind::OrderBy);
        assert_eq!("group".parse::<SqlClauseKind>()?, SqlClauseKind::GroupBy);
        assert_eq!("order".parse::<SqlClauseKind>()?, SqlClauseKind::OrderBy);
        assert_eq!("returning".parse::<SqlClauseKind>()?, SqlClauseKind::Returning);
        Ok(())
    }

    #[test]
    fn display_round_trips_through_parse() -> Result<(), SqlClauseParseError> {
        for kind in ALL_KINDS {
            assert_eq!(kind.to_string().parse::<SqlClauseKind>()?, kind);
        }
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!("".parse::<SqlClauseKind>(), Err(SqlClauseParseError::Empty));
        assert_eq!(" \t ".parse::<SqlClauseKind>(), Err(SqlClauseParseError::Empty));
        assert_eq!("window".parse::<SqlClauseKind>(), Err(SqlClauseParseError::Unknown));
        assert_eq!(
            "group by by".parse::<SqlClauseKind>(),
            Err(SqlClauseParseError::Unknown)
        );
    }

    #[test]
    fn orders_clause_kinds() {
        assert!(SqlClauseKind::Where.comes_after(SqlClauseKind::From));
        assert!(SqlClauseKind::From.comes_before(SqlClauseKind::Where));
        assert!(!SqlClauseKind::Limit.comes_before(SqlClauseKind::Limit));
        assert!(!SqlClauseKind::Limit.comes_after(SqlClauseKind::Limit));
        for pair in ALL_KINDS.windows(2) {
            assert!(pair[0].comes_before(pair[1]));
            assert!(pair[0] < pair[1]);
        }
    }

    #[test]
    fn sorts_clause_kinds() {
        let sorted = SqlClauseOrder::sort_kinds(vec![SqlClauseKind::Where, SqlClauseKind::Select]);
        assert_eq!(sorted, vec![SqlClauseKind::Select, SqlClauseKind::Where]);
        let mut reversed = ALL_KINDS.to_vec();
        reversed.reverse();
        assert_eq!(SqlClauseOrder::sort_kinds(reversed), ALL_KINDS.to_vec());
    }

    #[test]
    fn clause_displays_only_its_kind() {
        let clause = SqlClause::new(SqlClauseKind::Where).with_text("id = 1");
        assert_eq!(clause.kind(), SqlClauseKind::Where);
        assert_eq!(clause.text(), Some("id = 1"));
        assert_eq!(clause.to_string(), "WHERE");
        assert_eq!(SqlClause::new(SqlClauseKind::Limit).text(), None);
    }
}