Skip to main content

use_sql_dialect/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// Lightweight SQL dialect labels.
///
/// Each variant is a plain, copyable tag. `Ansi` is the `Default` variant, and
/// the derived ordering follows the declaration order below.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlDialect {
    /// Standard SQL (label `ansi`; also parsed from `standard` and `sqlstandard`).
    #[default]
    Ansi,
    /// PostgreSQL (label `postgresql`; also parsed from `postgres`).
    PostgreSql,
    /// SQLite (label `sqlite`; also parsed from `sqlite3`).
    SQLite,
    /// MySQL (label `mysql`).
    MySql,
    /// MariaDB (label `mariadb`); grouped into the MySQL family.
    MariaDb,
    /// SQL Server (label `sql-server`; also parsed from `mssql` and `tsql`).
    SqlServer,
    /// Oracle (label `oracle`).
    Oracle,
    /// DuckDB (label `duckdb`); grouped into the analytical family.
    DuckDb,
    /// BigQuery (label `bigquery`); grouped into the analytical family.
    BigQuery,
    /// Snowflake (label `snowflake`); grouped into the analytical family.
    Snowflake,
}
22
23impl SqlDialect {
24    /// Returns the stable lowercase dialect label.
25    #[must_use]
26    pub const fn as_str(self) -> &'static str {
27        match self {
28            Self::Ansi => "ansi",
29            Self::PostgreSql => "postgresql",
30            Self::SQLite => "sqlite",
31            Self::MySql => "mysql",
32            Self::MariaDb => "mariadb",
33            Self::SqlServer => "sql-server",
34            Self::Oracle => "oracle",
35            Self::DuckDb => "duckdb",
36            Self::BigQuery => "bigquery",
37            Self::Snowflake => "snowflake",
38        }
39    }
40
41    /// Returns the broad dialect family.
42    #[must_use]
43    pub const fn family(self) -> SqlDialectFamily {
44        match self {
45            Self::Ansi => SqlDialectFamily::Standard,
46            Self::PostgreSql => SqlDialectFamily::PostgreSql,
47            Self::SQLite => SqlDialectFamily::SQLite,
48            Self::MySql | Self::MariaDb => SqlDialectFamily::MySql,
49            Self::SqlServer => SqlDialectFamily::SqlServer,
50            Self::Oracle => SqlDialectFamily::Oracle,
51            Self::DuckDb | Self::BigQuery | Self::Snowflake => SqlDialectFamily::Analytical,
52        }
53    }
54}
55
56impl fmt::Display for SqlDialect {
57    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
58        formatter.write_str(self.as_str())
59    }
60}
61
62impl FromStr for SqlDialect {
63    type Err = SqlDialectParseError;
64
65    fn from_str(input: &str) -> Result<Self, Self::Err> {
66        match normalized_label(input)?.as_str() {
67            "ansi" | "sqlstandard" | "standard" => Ok(Self::Ansi),
68            "postgres" | "postgresql" => Ok(Self::PostgreSql),
69            "sqlite" | "sqlite3" => Ok(Self::SQLite),
70            "mysql" => Ok(Self::MySql),
71            "mariadb" => Ok(Self::MariaDb),
72            "sqlserver" | "mssql" | "tsql" => Ok(Self::SqlServer),
73            "oracle" => Ok(Self::Oracle),
74            "duckdb" => Ok(Self::DuckDb),
75            "bigquery" => Ok(Self::BigQuery),
76            "snowflake" => Ok(Self::Snowflake),
77            _ => Err(SqlDialectParseError::Unknown),
78        }
79    }
80}
81
82impl TryFrom<&str> for SqlDialect {
83    type Error = SqlDialectParseError;
84
85    fn try_from(value: &str) -> Result<Self, Self::Error> {
86        value.parse()
87    }
88}
89
/// Broad SQL dialect families.
///
/// `Standard` is the `Default` variant, and the derived ordering follows the
/// declaration order below.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlDialectFamily {
    /// Standard SQL (label `standard`).
    #[default]
    Standard,
    /// PostgreSQL family (label `postgresql`).
    PostgreSql,
    /// SQLite family (label `sqlite`).
    SQLite,
    /// MySQL family (label `mysql`).
    MySql,
    /// SQL Server family (label `sql-server`).
    SqlServer,
    /// Oracle family (label `oracle`).
    Oracle,
    /// Analytical engines (label `analytical`).
    Analytical,
}

impl SqlDialectFamily {
    /// Returns the stable lowercase family label as a `'static` literal.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Standard => "standard",
            Self::PostgreSql => "postgresql",
            Self::SQLite => "sqlite",
            Self::MySql => "mysql",
            Self::SqlServer => "sql-server",
            Self::Oracle => "oracle",
            Self::Analytical => "analytical",
        }
    }
}

impl fmt::Display for SqlDialectFamily {
    /// Writes the label returned by [`SqlDialectFamily::as_str`].
    ///
    /// The label goes through [`fmt::Formatter::pad`] rather than `write_str`,
    /// so width, fill/alignment and precision flags (for example `{:>12}`)
    /// are honoured instead of being silently ignored. Plain `{}` output is
    /// exactly the label.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.pad(self.as_str())
    }
}
124
/// Reasons a SQL dialect label can fail to parse.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqlDialectParseError {
    /// The label was empty once surrounding whitespace was trimmed.
    Empty,
    /// The label did not match any known dialect name or alias.
    Unknown,
}

impl fmt::Display for SqlDialectParseError {
    /// Writes a short, fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Unknown => "unknown SQL dialect label",
            Self::Empty => "SQL dialect label cannot be empty",
        };
        f.write_str(message)
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl Error for SqlDialectParseError {}
142
143fn normalized_label(input: &str) -> Result<String, SqlDialectParseError> {
144    let trimmed = input.trim();
145    if trimmed.is_empty() {
146        return Err(SqlDialectParseError::Empty);
147    }
148
149    Ok(trimmed
150        .chars()
151        .filter(|character| !matches!(character, '-' | '_' | ' '))
152        .collect::<String>()
153        .to_ascii_lowercase())
154}
155
#[cfg(test)]
mod tests {
    use super::{SqlDialect, SqlDialectFamily, SqlDialectParseError};

    /// Every dialect variant, so table-driven tests cover the whole enum.
    const ALL_DIALECTS: [SqlDialect; 10] = [
        SqlDialect::Ansi,
        SqlDialect::PostgreSql,
        SqlDialect::SQLite,
        SqlDialect::MySql,
        SqlDialect::MariaDb,
        SqlDialect::SqlServer,
        SqlDialect::Oracle,
        SqlDialect::DuckDb,
        SqlDialect::BigQuery,
        SqlDialect::Snowflake,
    ];

    #[test]
    fn parses_common_dialects() -> Result<(), SqlDialectParseError> {
        assert_eq!("postgres".parse::<SqlDialect>()?, SqlDialect::PostgreSql);
        assert_eq!("sql server".parse::<SqlDialect>()?, SqlDialect::SqlServer);
        assert_eq!(SqlDialect::Snowflake.family(), SqlDialectFamily::Analytical);
        Ok(())
    }

    #[test]
    fn rejects_unknown_dialects() {
        assert_eq!("".parse::<SqlDialect>(), Err(SqlDialectParseError::Empty));
        // Whitespace-only input is trimmed away and reported as empty.
        assert_eq!("   ".parse::<SqlDialect>(), Err(SqlDialectParseError::Empty));
        assert_eq!(
            "firebird".parse::<SqlDialect>(),
            Err(SqlDialectParseError::Unknown)
        );
    }

    #[test]
    fn canonical_labels_round_trip() {
        // Guards against a label (e.g. `sql-server`) drifting away from what
        // the parser accepts.
        for dialect in ALL_DIALECTS {
            assert_eq!(dialect.as_str().parse::<SqlDialect>(), Ok(dialect));
            assert_eq!(dialect.to_string(), dialect.as_str());
        }
    }

    #[test]
    fn ignores_case_separators_and_padding() {
        assert_eq!(
            "  Postgre_SQL ".parse::<SqlDialect>(),
            Ok(SqlDialect::PostgreSql)
        );
        assert_eq!("MS-SQL".parse::<SqlDialect>(), Ok(SqlDialect::SqlServer));
        assert_eq!("T_SQL".parse::<SqlDialect>(), Ok(SqlDialect::SqlServer));
        assert_eq!("SQLite3".parse::<SqlDialect>(), Ok(SqlDialect::SQLite));
        assert_eq!("sql standard".parse::<SqlDialect>(), Ok(SqlDialect::Ansi));
    }

    #[test]
    fn try_from_matches_parse() {
        assert_eq!(SqlDialect::try_from("mssql"), Ok(SqlDialect::SqlServer));
        assert_eq!(
            SqlDialect::try_from("firebird"),
            Err(SqlDialectParseError::Unknown)
        );
    }

    #[test]
    fn groups_related_dialects_into_families() {
        assert_eq!(SqlDialect::MariaDb.family(), SqlDialectFamily::MySql);
        assert_eq!(SqlDialect::MySql.family(), SqlDialectFamily::MySql);
        assert_eq!(SqlDialect::DuckDb.family(), SqlDialectFamily::Analytical);
        assert_eq!(SqlDialect::default().family(), SqlDialectFamily::default());
    }
}