1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// A SQL dialect identified by a stable, lowercase label.
///
/// The default dialect is [`SqlDialect::Ansi`]. Variant order is significant:
/// the derived `Ord`/`PartialOrd` implementations compare by declaration order.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlDialect {
    /// Standard SQL, not tied to a specific engine.
    #[default]
    Ansi,
    /// PostgreSQL.
    PostgreSql,
    /// SQLite.
    SQLite,
    /// MySQL.
    MySql,
    /// MariaDB.
    MariaDb,
    /// SQL Server.
    SqlServer,
    /// Oracle.
    Oracle,
    /// DuckDB.
    DuckDb,
    /// BigQuery.
    BigQuery,
    /// Snowflake.
    Snowflake,
}

23impl SqlDialect {
24 #[must_use]
26 pub const fn as_str(self) -> &'static str {
27 match self {
28 Self::Ansi => "ansi",
29 Self::PostgreSql => "postgresql",
30 Self::SQLite => "sqlite",
31 Self::MySql => "mysql",
32 Self::MariaDb => "mariadb",
33 Self::SqlServer => "sql-server",
34 Self::Oracle => "oracle",
35 Self::DuckDb => "duckdb",
36 Self::BigQuery => "bigquery",
37 Self::Snowflake => "snowflake",
38 }
39 }
40
41 #[must_use]
43 pub const fn family(self) -> SqlDialectFamily {
44 match self {
45 Self::Ansi => SqlDialectFamily::Standard,
46 Self::PostgreSql => SqlDialectFamily::PostgreSql,
47 Self::SQLite => SqlDialectFamily::SQLite,
48 Self::MySql | Self::MariaDb => SqlDialectFamily::MySql,
49 Self::SqlServer => SqlDialectFamily::SqlServer,
50 Self::Oracle => SqlDialectFamily::Oracle,
51 Self::DuckDb | Self::BigQuery | Self::Snowflake => SqlDialectFamily::Analytical,
52 }
53 }
54}
55
56impl fmt::Display for SqlDialect {
57 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
58 formatter.write_str(self.as_str())
59 }
60}
61
62impl FromStr for SqlDialect {
63 type Err = SqlDialectParseError;
64
65 fn from_str(input: &str) -> Result<Self, Self::Err> {
66 match normalized_label(input)?.as_str() {
67 "ansi" | "sqlstandard" | "standard" => Ok(Self::Ansi),
68 "postgres" | "postgresql" => Ok(Self::PostgreSql),
69 "sqlite" | "sqlite3" => Ok(Self::SQLite),
70 "mysql" => Ok(Self::MySql),
71 "mariadb" => Ok(Self::MariaDb),
72 "sqlserver" | "mssql" | "tsql" => Ok(Self::SqlServer),
73 "oracle" => Ok(Self::Oracle),
74 "duckdb" => Ok(Self::DuckDb),
75 "bigquery" => Ok(Self::BigQuery),
76 "snowflake" => Ok(Self::Snowflake),
77 _ => Err(SqlDialectParseError::Unknown),
78 }
79 }
80}
81
82impl TryFrom<&str> for SqlDialect {
83 type Error = SqlDialectParseError;
84
85 fn try_from(value: &str) -> Result<Self, Self::Error> {
86 value.parse()
87 }
88}
89
/// A coarse grouping of SQL dialects.
///
/// The default family is [`SqlDialectFamily::Standard`]. Variant order is
/// significant: the derived `Ord`/`PartialOrd` implementations compare by
/// declaration order.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlDialectFamily {
    /// Standard SQL.
    #[default]
    Standard,
    /// The PostgreSQL family.
    PostgreSql,
    /// The SQLite family.
    SQLite,
    /// The MySQL family.
    MySql,
    /// The SQL Server family.
    SqlServer,
    /// The Oracle family.
    Oracle,
    /// Analytical engines.
    Analytical,
}

103impl SqlDialectFamily {
104 #[must_use]
106 pub const fn as_str(self) -> &'static str {
107 match self {
108 Self::Standard => "standard",
109 Self::PostgreSql => "postgresql",
110 Self::SQLite => "sqlite",
111 Self::MySql => "mysql",
112 Self::SqlServer => "sql-server",
113 Self::Oracle => "oracle",
114 Self::Analytical => "analytical",
115 }
116 }
117}
118
119impl fmt::Display for SqlDialectFamily {
120 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
121 formatter.write_str(self.as_str())
122 }
123}
124
/// Reasons a string could not be parsed as a SQL dialect.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqlDialectParseError {
    /// The label was empty, or contained only whitespace.
    Empty,
    /// The label did not match any known dialect or alias.
    Unknown,
}

132impl fmt::Display for SqlDialectParseError {
133 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
134 match self {
135 Self::Empty => formatter.write_str("SQL dialect label cannot be empty"),
136 Self::Unknown => formatter.write_str("unknown SQL dialect label"),
137 }
138 }
139}
140
141impl Error for SqlDialectParseError {}
142
143fn normalized_label(input: &str) -> Result<String, SqlDialectParseError> {
144 let trimmed = input.trim();
145 if trimmed.is_empty() {
146 return Err(SqlDialectParseError::Empty);
147 }
148
149 Ok(trimmed
150 .chars()
151 .filter(|character| !matches!(character, '-' | '_' | ' '))
152 .collect::<String>()
153 .to_ascii_lowercase())
154}
155
#[cfg(test)]
mod tests {
    use super::{SqlDialect, SqlDialectFamily, SqlDialectParseError};

    /// Aliases and separator-insensitive labels resolve to the right dialect.
    #[test]
    fn parses_common_dialects() -> Result<(), SqlDialectParseError> {
        assert_eq!("postgres".parse::<SqlDialect>()?, SqlDialect::PostgreSql);
        assert_eq!("sql server".parse::<SqlDialect>()?, SqlDialect::SqlServer);
        assert_eq!(SqlDialect::Snowflake.family(), SqlDialectFamily::Analytical);
        Ok(())
    }

    /// Empty and unrecognized labels produce distinct errors.
    #[test]
    fn rejects_unknown_dialects() {
        assert_eq!("".parse::<SqlDialect>(), Err(SqlDialectParseError::Empty));
        assert_eq!(
            "firebird".parse::<SqlDialect>(),
            Err(SqlDialectParseError::Unknown)
        );
    }

    /// Every canonical label written by `Display` parses back to its dialect.
    #[test]
    fn display_round_trips_through_parsing() -> Result<(), SqlDialectParseError> {
        let dialects = [
            SqlDialect::Ansi,
            SqlDialect::PostgreSql,
            SqlDialect::SQLite,
            SqlDialect::MySql,
            SqlDialect::MariaDb,
            SqlDialect::SqlServer,
            SqlDialect::Oracle,
            SqlDialect::DuckDb,
            SqlDialect::BigQuery,
            SqlDialect::Snowflake,
        ];

        for dialect in dialects {
            assert_eq!(dialect.to_string().parse::<SqlDialect>()?, dialect);
        }
        Ok(())
    }
}