Skip to main content

rustauth_core/db/sql/
dialect.rs

1use super::*;
2
3impl SqlDialect {
4    pub fn quote_identifier(self, identifier: &str) -> Result<String, RustAuthError> {
5        let quote = match self {
6            Self::MySql => '`',
7            Self::Postgres | Self::Sqlite => '"',
8        };
9        identifier
10            .split('.')
11            .map(|part| {
12                validate_identifier(self, part)?;
13                Ok(format!("{quote}{part}{quote}"))
14            })
15            .collect::<Result<Vec<_>, _>>()
16            .map(|parts| parts.join("."))
17    }
18
19    pub fn sanitize_identifier(self, identifier: &str) -> Result<String, RustAuthError> {
20        let sanitized = identifier
21            .chars()
22            .map(|character| {
23                if character.is_ascii_alphanumeric() || character == '_' {
24                    character
25                } else {
26                    '_'
27                }
28            })
29            .collect::<String>();
30        validate_identifier(self, &sanitized)?;
31        Ok(sanitized)
32    }
33
34    pub fn placeholder(self, index: usize) -> String {
35        match self {
36            Self::Postgres => format!("${index}"),
37            Self::MySql | Self::Sqlite => "?".to_owned(),
38        }
39    }
40
41    pub fn where_clause(
42        self,
43        table: &DbTable,
44        clauses: &[Where],
45    ) -> Result<SqlFragment, RustAuthError> {
46        self.where_clause_starting_at(table, clauses, 1)
47    }
48
49    pub fn where_clause_starting_at(
50        self,
51        table: &DbTable,
52        clauses: &[Where],
53        first_placeholder: usize,
54    ) -> Result<SqlFragment, RustAuthError> {
55        if clauses.is_empty() {
56            return Ok(SqlFragment::default());
57        }
58
59        let mut and_clauses = Vec::new();
60        let mut or_clauses = Vec::new();
61        for clause in clauses {
62            match clause.connector {
63                Connector::And => and_clauses.push(clause),
64                Connector::Or => or_clauses.push(clause),
65            }
66        }
67
68        let mut sql = String::from(" WHERE ");
69        let mut parts = Vec::new();
70        let mut params = Vec::new();
71
72        for clause in and_clauses {
73            parts.push(self.clause_sql(table, clause, &mut params, first_placeholder)?);
74        }
75
76        if !or_clauses.is_empty() {
77            let mut or_parts = Vec::new();
78            for clause in or_clauses {
79                or_parts.push(self.clause_sql(table, clause, &mut params, first_placeholder)?);
80            }
81            let or_sql = or_parts.join(" OR ");
82            if parts.is_empty() && or_parts.len() == 1 {
83                parts.push(or_sql);
84            } else {
85                parts.push(format!("({or_sql})"));
86            }
87        }
88
89        sql.push_str(&parts.join(" AND "));
90        Ok(SqlFragment { sql, params })
91    }
92
93    fn clause_sql(
94        self,
95        table: &DbTable,
96        clause: &Where,
97        params: &mut Vec<SqlParam>,
98        first_placeholder: usize,
99    ) -> Result<String, RustAuthError> {
100        let (_, field) = resolve_field(table, &clause.field)?;
101        let column = self.quote_identifier(&field.name)?;
102        if clause.value == DbValue::Null {
103            return Ok(match clause.operator {
104                WhereOperator::Eq => format!("{column} IS NULL"),
105                WhereOperator::Ne => format!("{column} IS NOT NULL"),
106                _ => {
107                    return Err(RustAuthError::Adapter(
108                        "null only supports Eq and Ne operators".to_owned(),
109                    ))
110                }
111            });
112        }
113
114        match clause.operator {
115            WhereOperator::Eq
116            | WhereOperator::Ne
117            | WhereOperator::Lt
118            | WhereOperator::Lte
119            | WhereOperator::Gt
120            | WhereOperator::Gte => {
121                let operator = match clause.operator {
122                    WhereOperator::Eq => "=",
123                    WhereOperator::Ne => "!=",
124                    WhereOperator::Lt => "<",
125                    WhereOperator::Lte => "<=",
126                    WhereOperator::Gt => ">",
127                    WhereOperator::Gte => ">=",
128                    _ => {
129                        return Err(RustAuthError::Adapter(
130                            "unsupported scalar where operator".to_owned(),
131                        ));
132                    }
133                };
134                let placeholder =
135                    self.push_param(params, field, clause.value.clone(), first_placeholder);
136                if clause.mode == WhereMode::Insensitive
137                    && field.field_type == DbFieldType::String
138                    && matches!(&clause.value, DbValue::String(_))
139                    && matches!(clause.operator, WhereOperator::Eq | WhereOperator::Ne)
140                {
141                    Ok(format!("LOWER({column}) {operator} LOWER({placeholder})"))
142                } else {
143                    Ok(format!("{column} {operator} {placeholder}"))
144                }
145            }
146            WhereOperator::In | WhereOperator::NotIn => {
147                let placeholders =
148                    self.push_array_params(params, field, &clause.value, first_placeholder)?;
149                if placeholders.is_empty() {
150                    return Ok(if clause.operator == WhereOperator::In {
151                        "1 = 0".to_owned()
152                    } else {
153                        "1 = 1".to_owned()
154                    });
155                }
156                let operator = if clause.operator == WhereOperator::In {
157                    "IN"
158                } else {
159                    "NOT IN"
160                };
161                let placeholders = if clause.mode == WhereMode::Insensitive
162                    && field.field_type == DbFieldType::String
163                    && matches!(&clause.value, DbValue::StringArray(_))
164                {
165                    placeholders
166                        .into_iter()
167                        .map(|placeholder| format!("LOWER({placeholder})"))
168                        .collect::<Vec<_>>()
169                } else {
170                    placeholders
171                };
172                let column = if clause.mode == WhereMode::Insensitive
173                    && field.field_type == DbFieldType::String
174                    && matches!(&clause.value, DbValue::StringArray(_))
175                {
176                    format!("LOWER({column})")
177                } else {
178                    column
179                };
180                Ok(format!("{column} {operator} ({})", placeholders.join(", ")))
181            }
182            WhereOperator::Contains | WhereOperator::StartsWith | WhereOperator::EndsWith => {
183                let DbValue::String(value) = &clause.value else {
184                    return Err(RustAuthError::Adapter(
185                        "string pattern operators require string values".to_owned(),
186                    ));
187                };
188                let value = escape_like_pattern(value);
189                let pattern = match clause.operator {
190                    WhereOperator::Contains => format!("%{value}%"),
191                    WhereOperator::StartsWith => format!("{value}%"),
192                    WhereOperator::EndsWith => format!("%{value}"),
193                    _ => {
194                        return Err(RustAuthError::Adapter(
195                            "unsupported string pattern where operator".to_owned(),
196                        ));
197                    }
198                };
199                let placeholder =
200                    self.push_param(params, field, DbValue::String(pattern), first_placeholder);
201                if clause.mode == WhereMode::Insensitive {
202                    if self == Self::Postgres {
203                        Ok(format!(
204                            "{column} ILIKE {placeholder} {}",
205                            self.like_escape_clause()
206                        ))
207                    } else {
208                        Ok(format!(
209                            "LOWER({column}) LIKE LOWER({placeholder}) {}",
210                            self.like_escape_clause()
211                        ))
212                    }
213                } else {
214                    Ok(format!(
215                        "{column} LIKE {placeholder} {}",
216                        self.like_escape_clause()
217                    ))
218                }
219            }
220        }
221    }
222
223    fn push_param(
224        &self,
225        params: &mut Vec<SqlParam>,
226        field: &DbField,
227        value: DbValue,
228        first_placeholder: usize,
229    ) -> String {
230        params.push(SqlParam::new(field, value));
231        self.placeholder(first_placeholder + params.len() - 1)
232    }
233
234    fn push_array_params(
235        self,
236        params: &mut Vec<SqlParam>,
237        field: &DbField,
238        value: &DbValue,
239        first_placeholder: usize,
240    ) -> Result<Vec<String>, RustAuthError> {
241        match value {
242            DbValue::StringArray(values) => Ok(values
243                .iter()
244                .map(|value| {
245                    self.push_param(
246                        params,
247                        field,
248                        DbValue::String(value.clone()),
249                        first_placeholder,
250                    )
251                })
252                .collect()),
253            DbValue::NumberArray(values) => Ok(values
254                .iter()
255                .map(|value| {
256                    self.push_param(params, field, DbValue::Number(*value), first_placeholder)
257                })
258                .collect()),
259            _ => Err(RustAuthError::Adapter(
260                "IN and NOT IN require array values".to_owned(),
261            )),
262        }
263    }
264
265    pub fn order_limit_offset(
266        self,
267        table: &DbTable,
268        sort_by: Option<&Sort>,
269        limit: Option<usize>,
270        offset: Option<usize>,
271    ) -> Result<String, RustAuthError> {
272        let mut sql = String::new();
273        if let Some(sort) = sort_by {
274            let (_, field) = resolve_field(table, &sort.field)?;
275            let direction = match sort.direction {
276                SortDirection::Asc => "ASC",
277                SortDirection::Desc => "DESC",
278            };
279            sql.push_str(" ORDER BY ");
280            sql.push_str(&self.quote_identifier(&field.name)?);
281            sql.push(' ');
282            sql.push_str(direction);
283        }
284        if let Some(limit) = limit {
285            sql.push_str(" LIMIT ");
286            sql.push_str(&limit.to_string());
287        }
288        if let Some(offset) = offset {
289            sql.push_str(" OFFSET ");
290            sql.push_str(&offset.to_string());
291        }
292        Ok(sql)
293    }
294
295    pub fn column_definition(
296        self,
297        logical_name: &str,
298        field: &DbField,
299    ) -> Result<String, RustAuthError> {
300        let mut parts = vec![
301            self.quote_identifier(&field.name)?,
302            self.sql_type(logical_name, field),
303        ];
304        if logical_name == "id" || field.name == "id" {
305            match (self, field.generated_id) {
306                (Self::Postgres, Some(IdGeneration::Serial)) => {
307                    parts.push("GENERATED BY DEFAULT AS IDENTITY".to_owned());
308                }
309                (Self::Postgres, Some(IdGeneration::Uuid)) => {
310                    parts.push("DEFAULT pg_catalog.gen_random_uuid()".to_owned());
311                }
312                (Self::MySql, Some(IdGeneration::Serial)) => {
313                    parts.push("AUTO_INCREMENT".to_owned());
314                }
315                _ => {}
316            }
317            parts.push("PRIMARY KEY".to_owned());
318        } else {
319            if field.required {
320                parts.push("NOT NULL".to_owned());
321            }
322            if field.unique {
323                parts.push("UNIQUE".to_owned());
324            }
325        }
326        if let Some(foreign_key) = &field.foreign_key {
327            parts.push(format!(
328                "REFERENCES {} ({})",
329                self.quote_identifier(&foreign_key.table)?,
330                self.quote_identifier(&foreign_key.field)?
331            ));
332            parts.push(on_delete_sql(foreign_key.on_delete).to_owned());
333        }
334        Ok(parts.join(" "))
335    }
336
337    pub fn create_table_statement(self, table: &DbTable) -> Result<String, RustAuthError> {
338        let columns = table
339            .fields
340            .iter()
341            .map(|(logical_name, field)| self.column_definition(logical_name, field))
342            .collect::<Result<Vec<_>, _>>()?;
343        let suffix = match self {
344            Self::MySql => " ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci",
345            Self::Postgres | Self::Sqlite => "",
346        };
347        Ok(format!(
348            "CREATE TABLE IF NOT EXISTS {} ({}){}",
349            self.quote_identifier(&table.name)?,
350            columns.join(", "),
351            suffix
352        ))
353    }
354
355    pub fn add_column_statement(
356        self,
357        table: &str,
358        logical_name: &str,
359        field: &DbField,
360    ) -> Result<String, RustAuthError> {
361        Ok(format!(
362            "ALTER TABLE {} ADD COLUMN {}",
363            self.quote_identifier(table)?,
364            self.column_definition(logical_name, field)?,
365        ))
366    }
367
368    pub fn create_index_statement(
369        self,
370        table: &str,
371        column: &str,
372        index: &str,
373        unique: bool,
374    ) -> Result<String, RustAuthError> {
375        let if_not_exists = match self {
376            Self::Postgres | Self::Sqlite => " IF NOT EXISTS",
377            Self::MySql => "",
378        };
379        let unique = if unique { "UNIQUE " } else { "" };
380        Ok(format!(
381            "CREATE {unique}INDEX{} {} ON {} ({})",
382            if_not_exists,
383            self.quote_identifier(index)?,
384            self.quote_identifier(table)?,
385            self.quote_identifier(column)?,
386        ))
387    }
388
389    pub fn sql_type(self, logical_name: &str, field: &DbField) -> String {
390        match self {
391            Self::Postgres => match field.field_type {
392                DbFieldType::String if field.generated_id == Some(IdGeneration::Uuid) => "UUID",
393                DbFieldType::String => "TEXT",
394                DbFieldType::Number => "BIGINT",
395                DbFieldType::Boolean => "BOOLEAN",
396                DbFieldType::Timestamp => "TIMESTAMPTZ",
397                DbFieldType::Json => "JSONB",
398                DbFieldType::StringArray => "TEXT[]",
399                DbFieldType::NumberArray => "BIGINT[]",
400            }
401            .to_owned(),
402            Self::Sqlite => match field.field_type {
403                DbFieldType::Number if field.generated_id == Some(IdGeneration::Serial) => {
404                    "INTEGER"
405                }
406                DbFieldType::String
407                | DbFieldType::Timestamp
408                | DbFieldType::Json
409                | DbFieldType::StringArray
410                | DbFieldType::NumberArray => "TEXT",
411                DbFieldType::Number | DbFieldType::Boolean => "INTEGER",
412            }
413            .to_owned(),
414            Self::MySql => match field.field_type {
415                DbFieldType::Number if field.generated_id == Some(IdGeneration::Serial) => "BIGINT",
416                DbFieldType::String
417                    if logical_name == "id"
418                        || field.unique
419                        || field.index
420                        || field.foreign_key.is_some() =>
421                {
422                    "VARCHAR(255)"
423                }
424                DbFieldType::String => "TEXT",
425                DbFieldType::Number => "BIGINT",
426                DbFieldType::Boolean => "BOOLEAN",
427                DbFieldType::Timestamp => "DATETIME(6)",
428                DbFieldType::Json | DbFieldType::StringArray | DbFieldType::NumberArray => "JSON",
429            }
430            .to_owned(),
431        }
432    }
433
434    pub fn type_matches(self, actual: &str, field: &DbField) -> bool {
435        let actual = normalized_type(actual);
436        match self {
437            Self::Postgres => match field.field_type {
438                DbFieldType::String => {
439                    matches!(
440                        actual.as_str(),
441                        "text" | "character varying" | "varchar" | "uuid"
442                    )
443                }
444                DbFieldType::Number => matches!(
445                    actual.as_str(),
446                    "bigint"
447                        | "integer"
448                        | "smallint"
449                        | "numeric"
450                        | "real"
451                        | "double precision"
452                        | "int8"
453                        | "int4"
454                        | "int2"
455                ),
456                DbFieldType::Boolean => matches!(actual.as_str(), "boolean" | "bool"),
457                DbFieldType::Timestamp => matches!(
458                    actual.as_str(),
459                    "timestamp with time zone"
460                        | "timestamp without time zone"
461                        | "timestamp"
462                        | "timestamptz"
463                        | "date"
464                ),
465                DbFieldType::Json => matches!(actual.as_str(), "jsonb" | "json"),
466                DbFieldType::StringArray => {
467                    matches!(actual.as_str(), "text[]" | "_text" | "_varchar" | "_bpchar")
468                }
469                DbFieldType::NumberArray => matches!(
470                    actual.as_str(),
471                    "bigint[]" | "integer[]" | "_int8" | "_int4" | "_int2"
472                ),
473            },
474            Self::MySql => match field.field_type {
475                DbFieldType::String => matches!(actual.as_str(), "varchar" | "text" | "uuid"),
476                DbFieldType::Number => matches!(
477                    actual.as_str(),
478                    "integer" | "int" | "bigint" | "smallint" | "decimal" | "float" | "double"
479                ),
480                DbFieldType::Boolean => matches!(actual.as_str(), "boolean" | "tinyint"),
481                DbFieldType::Timestamp => {
482                    matches!(actual.as_str(), "timestamp" | "datetime" | "date")
483                }
484                DbFieldType::Json | DbFieldType::StringArray | DbFieldType::NumberArray => {
485                    actual.as_str() == "json"
486                }
487            },
488            Self::Sqlite => match field.field_type {
489                DbFieldType::String
490                | DbFieldType::Timestamp
491                | DbFieldType::Json
492                | DbFieldType::StringArray
493                | DbFieldType::NumberArray => matches!(
494                    actual.as_str(),
495                    "text" | "varchar" | "character varying" | "nvarchar" | "clob"
496                ),
497                DbFieldType::Number => matches!(
498                    actual.as_str(),
499                    "integer"
500                        | "int"
501                        | "bigint"
502                        | "smallint"
503                        | "tinyint"
504                        | "numeric"
505                        | "real"
506                        | "double"
507                ),
508                DbFieldType::Boolean => matches!(
509                    actual.as_str(),
510                    "integer" | "int" | "bigint" | "smallint" | "tinyint" | "boolean" | "bool"
511                ),
512            },
513        }
514    }
515}
516
/// Makes `value` match literally inside a `LIKE` pattern.
///
/// Each `%` and `_` wildcard, and the backslash escape character itself, is prefixed
/// with a backslash. All other characters are copied unchanged.
fn escape_like_pattern(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if let '%' | '_' | '\\' = ch {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
527
528fn validate_identifier(dialect: SqlDialect, identifier: &str) -> Result<(), RustAuthError> {
529    let mut chars = identifier.chars();
530    let Some(first) = chars.next() else {
531        return Err(RustAuthError::Adapter(format!(
532            "{} identifier cannot be empty",
533            dialect.name()
534        )));
535    };
536    if !(first.is_ascii_alphabetic() || first == '_') {
537        return Err(invalid_identifier(dialect, identifier));
538    }
539    if chars.any(|character| !(character.is_ascii_alphanumeric() || character == '_')) {
540        return Err(invalid_identifier(dialect, identifier));
541    }
542    Ok(())
543}
544
545fn invalid_identifier(dialect: SqlDialect, identifier: &str) -> RustAuthError {
546    RustAuthError::Adapter(format!(
547        "invalid {} identifier `{identifier}`",
548        dialect.name()
549    ))
550}
551
552impl SqlDialect {
553    fn name(self) -> &'static str {
554        match self {
555            Self::Postgres => "postgres",
556            Self::MySql => "mysql",
557            Self::Sqlite => "sqlite",
558        }
559    }
560
561    fn like_escape_clause(self) -> &'static str {
562        match self {
563            Self::MySql => "ESCAPE '\\\\'",
564            Self::Postgres | Self::Sqlite => "ESCAPE '\\'",
565        }
566    }
567}
568
/// Canonicalizes a database-reported column type for comparison.
///
/// Surrounding whitespace and everything from the first `(` onwards (e.g. a length
/// suffix) are dropped, and the rest is lower-cased. For example, `" VARCHAR(255) "`
/// becomes `"varchar"`.
fn normalized_type(value: &str) -> String {
    let trimmed = value.trim();
    let base = match trimmed.find('(') {
        Some(paren) => &trimmed[..paren],
        None => trimmed,
    };
    base.trim().to_ascii_lowercase()
}
578
579fn on_delete_sql(on_delete: OnDelete) -> &'static str {
580    match on_delete {
581        OnDelete::NoAction => "ON DELETE NO ACTION",
582        OnDelete::Restrict => "ON DELETE RESTRICT",
583        OnDelete::Cascade => "ON DELETE CASCADE",
584        OnDelete::SetNull => "ON DELETE SET NULL",
585        OnDelete::SetDefault => "ON DELETE SET DEFAULT",
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// A `users` table with two string columns, `email` and `name`, shared by the tests.
    fn user_table() -> DbTable {
        let mut columns = IndexMap::new();
        for column in ["email", "name"] {
            columns.insert(column.to_owned(), DbField::new(column, DbFieldType::String));
        }
        DbTable {
            name: "users".to_owned(),
            fields: columns,
            order: None,
        }
    }

    /// Renders a single clause against `user_table()` for `dialect`.
    fn render(dialect: SqlDialect, clause: Where) -> Result<SqlFragment, RustAuthError> {
        dialect.where_clause(&user_table(), &[clause])
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_eq() -> Result<(), RustAuthError> {
        let email = DbValue::String("ADA@EXAMPLE.COM".to_owned());
        let condition = Where::new("email", email).insensitive();

        let rendered = render(SqlDialect::Postgres, condition)?;

        assert_eq!(rendered.sql, r#" WHERE LOWER("email") = LOWER($1)"#);
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_ne() -> Result<(), RustAuthError> {
        let email = DbValue::String("ADA@EXAMPLE.COM".to_owned());
        let condition = Where::new("email", email)
            .operator(WhereOperator::Ne)
            .insensitive();

        let rendered = render(SqlDialect::Postgres, condition)?;

        assert_eq!(rendered.sql, r#" WHERE LOWER("email") != LOWER($1)"#);
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_in() -> Result<(), RustAuthError> {
        let emails = DbValue::StringArray(vec![
            "ADA@EXAMPLE.COM".to_owned(),
            "GRACE@EXAMPLE.COM".to_owned(),
        ]);
        let condition = Where::new("email", emails)
            .operator(WhereOperator::In)
            .insensitive();

        let rendered = render(SqlDialect::Postgres, condition)?;

        assert_eq!(
            rendered.sql,
            r#" WHERE LOWER("email") IN (LOWER($1), LOWER($2))"#
        );
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_not_in() -> Result<(), RustAuthError> {
        let emails = DbValue::StringArray(vec!["ADA@EXAMPLE.COM".to_owned()]);
        let condition = Where::new("email", emails)
            .operator(WhereOperator::NotIn)
            .insensitive();

        let rendered = render(SqlDialect::Postgres, condition)?;

        assert_eq!(rendered.sql, r#" WHERE LOWER("email") NOT IN (LOWER($1))"#);
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_contains() -> Result<(), RustAuthError> {
        let needle = DbValue::String(r"a%b_c\d".to_owned());
        let condition = Where::new("email", needle).operator(WhereOperator::Contains);

        let rendered = render(SqlDialect::Postgres, condition)?;

        assert_eq!(rendered.sql, r#" WHERE "email" LIKE $1 ESCAPE '\'"#);
        assert_eq!(
            rendered.params[0].value,
            DbValue::String(r"%a\%b\_c\\d%".to_owned())
        );
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_starts_with() -> Result<(), RustAuthError> {
        let prefix = DbValue::String("100%_".to_owned());
        let condition = Where::new("email", prefix).operator(WhereOperator::StartsWith);

        let rendered = render(SqlDialect::Sqlite, condition)?;

        assert_eq!(rendered.sql, r#" WHERE "email" LIKE ? ESCAPE '\'"#);
        assert_eq!(
            rendered.params[0].value,
            DbValue::String(r"100\%\_%".to_owned())
        );
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_insensitive_ends_with() -> Result<(), RustAuthError>
    {
        let suffix = DbValue::String(r"\_%".to_owned());
        let condition = Where::new("email", suffix)
            .operator(WhereOperator::EndsWith)
            .insensitive();

        let rendered = render(SqlDialect::MySql, condition)?;

        assert_eq!(
            rendered.sql,
            " WHERE LOWER(`email`) LIKE LOWER(?) ESCAPE '\\\\'"
        );
        assert_eq!(
            rendered.params[0].value,
            DbValue::String(r"%\\\_\%".to_owned())
        );
        Ok(())
    }
}