1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7use use_sql_ident::{SqlIdentifier, SqlIdentifierError, is_valid_unquoted_ident};
8
/// The placeholder syntaxes understood by [`SqlParameter`].
///
/// The default style is [`SqlParameterStyle::PostgresIndexed`]. Variants are
/// ordered by declaration, which the derived `Ord` relies on.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlParameterStyle {
    /// One-based numeric placeholder rendered as `$n`.
    #[default]
    PostgresIndexed,
    /// Anonymous positional placeholder rendered as `?`.
    PositionalQuestionMark,
    /// Named placeholder rendered as `:name`.
    NamedColon,
    /// Named placeholder rendered as `@name`.
    NamedAtSign,
}
18
19#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
21pub struct SqlParameterIndex(u32);
22
23impl SqlParameterIndex {
24 pub const fn new(index: u32) -> Result<Self, SqlParameterError> {
30 if index == 0 {
31 Err(SqlParameterError::ZeroIndex)
32 } else {
33 Ok(Self(index))
34 }
35 }
36
37 #[must_use]
39 pub const fn get(self) -> u32 {
40 self.0
41 }
42}
43
44impl fmt::Display for SqlParameterIndex {
45 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46 write!(formatter, "{}", self.0)
47 }
48}
49
50#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
52pub struct SqlParameterName(SqlIdentifier);
53
54impl SqlParameterName {
55 pub fn new(input: impl AsRef<str>) -> Result<Self, SqlParameterError> {
62 let input = input.as_ref();
63 if !is_valid_unquoted_ident(input) {
64 return Err(SqlParameterError::InvalidName);
65 }
66
67 SqlIdentifier::new(input)
68 .map(Self)
69 .map_err(SqlParameterError::Identifier)
70 }
71
72 #[must_use]
74 pub fn as_str(&self) -> &str {
75 self.0.as_str()
76 }
77}
78
79impl AsRef<str> for SqlParameterName {
80 fn as_ref(&self) -> &str {
81 self.as_str()
82 }
83}
84
85impl fmt::Display for SqlParameterName {
86 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
87 self.0.fmt(formatter)
88 }
89}
90
91impl FromStr for SqlParameterName {
92 type Err = SqlParameterError;
93
94 fn from_str(input: &str) -> Result<Self, Self::Err> {
95 Self::new(input)
96 }
97}
98
99#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
101pub enum SqlParameter {
102 PostgresIndexed(SqlParameterIndex),
103 PositionalQuestionMark,
104 NamedColon(SqlParameterName),
105 NamedAtSign(SqlParameterName),
106}
107
108impl SqlParameter {
109 pub const fn postgres_indexed(index: u32) -> Result<Self, SqlParameterError> {
115 match SqlParameterIndex::new(index) {
116 Ok(index) => Ok(Self::PostgresIndexed(index)),
117 Err(error) => Err(error),
118 }
119 }
120
121 #[must_use]
123 pub const fn positional() -> Self {
124 Self::PositionalQuestionMark
125 }
126
127 pub fn named_colon(name: impl AsRef<str>) -> Result<Self, SqlParameterError> {
133 SqlParameterName::new(name).map(Self::NamedColon)
134 }
135
136 pub fn named_at(name: impl AsRef<str>) -> Result<Self, SqlParameterError> {
142 SqlParameterName::new(name).map(Self::NamedAtSign)
143 }
144
145 #[must_use]
147 pub const fn style(&self) -> SqlParameterStyle {
148 match self {
149 Self::PostgresIndexed(_) => SqlParameterStyle::PostgresIndexed,
150 Self::PositionalQuestionMark => SqlParameterStyle::PositionalQuestionMark,
151 Self::NamedColon(_) => SqlParameterStyle::NamedColon,
152 Self::NamedAtSign(_) => SqlParameterStyle::NamedAtSign,
153 }
154 }
155}
156
157impl fmt::Display for SqlParameter {
158 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
159 match self {
160 Self::PostgresIndexed(index) => write!(formatter, "${index}"),
161 Self::PositionalQuestionMark => formatter.write_str("?"),
162 Self::NamedColon(name) => write!(formatter, ":{name}"),
163 Self::NamedAtSign(name) => write!(formatter, "@{name}"),
164 }
165 }
166}
167
168impl FromStr for SqlParameter {
169 type Err = SqlParameterError;
170
171 fn from_str(input: &str) -> Result<Self, Self::Err> {
172 let trimmed = input.trim();
173 if trimmed.is_empty() {
174 return Err(SqlParameterError::Empty);
175 }
176 if trimmed == "?" {
177 return Ok(Self::positional());
178 }
179 if let Some(index) = trimmed.strip_prefix('$') {
180 if index.is_empty() || !index.chars().all(|character| character.is_ascii_digit()) {
181 return Err(SqlParameterError::InvalidIndexed);
182 }
183 let index = index
184 .parse::<u32>()
185 .map_err(|_| SqlParameterError::InvalidIndexed)?;
186 return Self::postgres_indexed(index);
187 }
188 if let Some(name) = trimmed.strip_prefix(':') {
189 return Self::named_colon(name);
190 }
191 if let Some(name) = trimmed.strip_prefix('@') {
192 return Self::named_at(name);
193 }
194 Err(SqlParameterError::UnknownStyle)
195 }
196}
197
198impl TryFrom<&str> for SqlParameter {
199 type Error = SqlParameterError;
200
201 fn try_from(value: &str) -> Result<Self, Self::Error> {
202 value.parse()
203 }
204}
205
206#[derive(Clone, Debug, Eq, PartialEq)]
208pub enum SqlParameterError {
209 Empty,
210 ZeroIndex,
211 InvalidIndexed,
212 InvalidName,
213 UnknownStyle,
214 Identifier(SqlIdentifierError),
215}
216
217impl fmt::Display for SqlParameterError {
218 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
219 match self {
220 Self::Empty => formatter.write_str("SQL parameter placeholder cannot be empty"),
221 Self::ZeroIndex => formatter.write_str("SQL parameter indexes are one-based"),
222 Self::InvalidIndexed => formatter.write_str("invalid indexed SQL parameter"),
223 Self::InvalidName => formatter.write_str("invalid SQL parameter name"),
224 Self::UnknownStyle => formatter.write_str("unknown SQL parameter placeholder style"),
225 Self::Identifier(error) => {
226 write!(formatter, "invalid SQL parameter identifier: {error}")
227 },
228 }
229 }
230}
231
// Marker impl: `source` keeps its default (`None`). The wrapped identifier
// error is already embedded in the `Display` message of `Identifier`, so
// returning it from `source` too would repeat it in error-chain reports.
impl Error for SqlParameterError {}
233
#[cfg(test)]
mod tests {
    //! Behaviour tests for placeholder construction, parsing and rendering.

    use super::{SqlParameter, SqlParameterError, SqlParameterStyle};

    #[test]
    fn parses_parameter_styles() -> Result<(), SqlParameterError> {
        assert_eq!("$1".parse::<SqlParameter>()?.to_string(), "$1");
        assert_eq!(
            "?".parse::<SqlParameter>()?.style(),
            SqlParameterStyle::PositionalQuestionMark
        );
        assert_eq!(":user_id".parse::<SqlParameter>()?.to_string(), ":user_id");
        assert_eq!("@user_id".parse::<SqlParameter>()?.to_string(), "@user_id");
        Ok(())
    }

    #[test]
    fn rejects_invalid_parameters() {
        assert_eq!(
            "$0".parse::<SqlParameter>(),
            Err(SqlParameterError::ZeroIndex)
        );
        assert_eq!(
            "$abc".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidIndexed)
        );
        assert_eq!(
            ":select".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidName)
        );
    }

    #[test]
    fn trims_whitespace_and_handles_boundaries() -> Result<(), SqlParameterError> {
        assert_eq!("  $2\t".parse::<SqlParameter>()?.to_string(), "$2");
        assert_eq!(" ? ".parse::<SqlParameter>()?, SqlParameter::positional());
        // u32::MAX is the largest representable index.
        assert_eq!(
            SqlParameter::try_from("$4294967295")?.to_string(),
            "$4294967295"
        );
        assert_eq!(
            SqlParameter::postgres_indexed(3)?.style(),
            SqlParameterStyle::PostgresIndexed
        );
        assert_eq!(SqlParameterStyle::default(), SqlParameterStyle::PostgresIndexed);
        Ok(())
    }

    #[test]
    fn rejects_malformed_placeholders() {
        assert_eq!("".parse::<SqlParameter>(), Err(SqlParameterError::Empty));
        assert_eq!("   ".parse::<SqlParameter>(), Err(SqlParameterError::Empty));
        assert_eq!(
            "$".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidIndexed)
        );
        assert_eq!(
            "$-1".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidIndexed)
        );
        // One past u32::MAX overflows and must not wrap.
        assert_eq!(
            "$4294967296".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidIndexed)
        );
        assert_eq!(
            "?1".parse::<SqlParameter>(),
            Err(SqlParameterError::UnknownStyle)
        );
        assert_eq!(
            "#1".parse::<SqlParameter>(),
            Err(SqlParameterError::UnknownStyle)
        );
        assert_eq!(
            "user_id".parse::<SqlParameter>(),
            Err(SqlParameterError::UnknownStyle)
        );
        assert_eq!(
            "@select".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidName)
        );
        assert_eq!(
            SqlParameter::postgres_indexed(0),
            Err(SqlParameterError::ZeroIndex)
        );
    }
}