1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// The kinds of SQL clause this crate can name, parse and order.
///
/// Variants are declared in the same relative order as the values returned by
/// [`SqlClauseKind::ordinal`], so the derived `Ord` agrees with
/// [`SqlClauseKind::comes_before`] and [`SqlClauseKind::comes_after`].
/// `Select` is the `Default` value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlClauseKind {
    /// The `SELECT` clause.
    #[default]
    Select,
    /// The `FROM` clause.
    From,
    /// The `WHERE` clause.
    Where,
    /// The `GROUP BY` clause.
    GroupBy,
    /// The `HAVING` clause.
    Having,
    /// The `ORDER BY` clause.
    OrderBy,
    /// The `LIMIT` clause.
    Limit,
    /// The `OFFSET` clause.
    Offset,
    /// The `RETURNING` clause.
    Returning,
}

impl SqlClauseKind {
    /// Returns the upper-case SQL keyword for this kind, e.g. `"GROUP BY"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            SqlClauseKind::Select => "SELECT",
            SqlClauseKind::From => "FROM",
            SqlClauseKind::Where => "WHERE",
            SqlClauseKind::GroupBy => "GROUP BY",
            SqlClauseKind::Having => "HAVING",
            SqlClauseKind::OrderBy => "ORDER BY",
            SqlClauseKind::Limit => "LIMIT",
            SqlClauseKind::Offset => "OFFSET",
            SqlClauseKind::Returning => "RETURNING",
        }
    }

    /// Returns the sort position of this kind.
    ///
    /// Values run from 10 (`Select`) to 90 (`Returning`) in steps of ten; a
    /// smaller ordinal means the clause sorts earlier.
    #[must_use]
    pub const fn ordinal(self) -> u8 {
        match self {
            SqlClauseKind::Select => 10,
            SqlClauseKind::From => 20,
            SqlClauseKind::Where => 30,
            SqlClauseKind::GroupBy => 40,
            SqlClauseKind::Having => 50,
            SqlClauseKind::OrderBy => 60,
            SqlClauseKind::Limit => 70,
            SqlClauseKind::Offset => 80,
            SqlClauseKind::Returning => 90,
        }
    }

    /// Returns `true` when this kind's ordinal is strictly smaller than `other`'s.
    ///
    /// A kind never comes before itself.
    #[must_use]
    pub const fn comes_before(self, other: Self) -> bool {
        other.ordinal() > self.ordinal()
    }

    /// Returns `true` when this kind's ordinal is strictly greater than `other`'s.
    ///
    /// A kind never comes after itself.
    #[must_use]
    pub const fn comes_after(self, other: Self) -> bool {
        other.ordinal() < self.ordinal()
    }
}

impl fmt::Display for SqlClauseKind {
    /// Writes the keyword from [`SqlClauseKind::as_str`] verbatim; width,
    /// fill and alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = self.as_str();
        f.write_str(keyword)
    }
}
73
74impl FromStr for SqlClauseKind {
75 type Err = SqlClauseParseError;
76
77 fn from_str(input: &str) -> Result<Self, Self::Err> {
78 match normalized_clause(input)?.as_str() {
79 "SELECT" => Ok(Self::Select),
80 "FROM" => Ok(Self::From),
81 "WHERE" => Ok(Self::Where),
82 "GROUP BY" | "GROUP" => Ok(Self::GroupBy),
83 "HAVING" => Ok(Self::Having),
84 "ORDER BY" | "ORDER" => Ok(Self::OrderBy),
85 "LIMIT" => Ok(Self::Limit),
86 "OFFSET" => Ok(Self::Offset),
87 "RETURNING" => Ok(Self::Returning),
88 _ => Err(SqlClauseParseError::Unknown),
89 }
90 }
91}
92
93#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
95pub struct SqlClause {
96 kind: SqlClauseKind,
97 text: Option<String>,
98}
99
100impl SqlClause {
101 #[must_use]
103 pub const fn new(kind: SqlClauseKind) -> Self {
104 Self { kind, text: None }
105 }
106
107 #[must_use]
109 pub fn with_text(mut self, text: impl Into<String>) -> Self {
110 self.text = Some(text.into());
111 self
112 }
113
114 #[must_use]
116 pub const fn kind(&self) -> SqlClauseKind {
117 self.kind
118 }
119
120 #[must_use]
122 pub fn text(&self) -> Option<&str> {
123 self.text.as_deref()
124 }
125}
126
127impl fmt::Display for SqlClause {
128 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
129 formatter.write_str(self.kind.as_str())
130 }
131}
132
133#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
135pub struct SqlClauseOrder;
136
137impl SqlClauseOrder {
138 #[must_use]
140 pub fn sort_kinds(mut kinds: Vec<SqlClauseKind>) -> Vec<SqlClauseKind> {
141 kinds.sort_by_key(|kind| kind.ordinal());
142 kinds
143 }
144}
145
/// Reasons a clause label can fail to parse.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlClauseParseError {
    /// The label contained nothing to parse.
    Empty,
    /// The label did not match any known clause keyword.
    Unknown,
}

impl fmt::Display for SqlClauseParseError {
    /// Writes a fixed, human-readable message for the error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "SQL clause label cannot be empty",
            Self::Unknown => "unknown SQL clause label",
        };
        f.write_str(message)
    }
}

// No underlying cause is carried, so the default `source()` (`None`) applies.
impl Error for SqlClauseParseError {}
163
164fn normalized_clause(input: &str) -> Result<String, SqlClauseParseError> {
165 let trimmed = input.trim();
166 if trimmed.is_empty() {
167 return Err(SqlClauseParseError::Empty);
168 }
169 Ok(trimmed
170 .replace('_', " ")
171 .split_whitespace()
172 .collect::<Vec<_>>()
173 .join(" ")
174 .to_ascii_uppercase())
175}
176
#[cfg(test)]
mod tests {
    use super::{SqlClauseKind, SqlClauseOrder, SqlClauseParseError};

    // A lower-case, space-separated label parses to its kind, and the
    // ordering helper places WHERE after FROM.
    #[test]
    fn parses_clause_labels() -> Result<(), SqlClauseParseError> {
        let parsed: SqlClauseKind = "group by".parse()?;
        assert_eq!(parsed, SqlClauseKind::GroupBy);

        let where_follows_from = SqlClauseKind::Where.comes_after(SqlClauseKind::From);
        assert!(where_follows_from);
        Ok(())
    }

    // Kinds supplied out of order come back sorted by ordinal.
    #[test]
    fn sorts_clause_kinds() {
        let unsorted = vec![SqlClauseKind::Where, SqlClauseKind::Select];
        let expected = vec![SqlClauseKind::Select, SqlClauseKind::Where];
        assert_eq!(SqlClauseOrder::sort_kinds(unsorted), expected);
    }
}