1use super::*;
2
3impl SqlDialect {
4 pub fn quote_identifier(self, identifier: &str) -> Result<String, RustAuthError> {
5 let quote = match self {
6 Self::MySql => '`',
7 Self::Postgres | Self::Sqlite => '"',
8 };
9 identifier
10 .split('.')
11 .map(|part| {
12 validate_identifier(self, part)?;
13 Ok(format!("{quote}{part}{quote}"))
14 })
15 .collect::<Result<Vec<_>, _>>()
16 .map(|parts| parts.join("."))
17 }
18
19 pub fn sanitize_identifier(self, identifier: &str) -> Result<String, RustAuthError> {
20 let sanitized = identifier
21 .chars()
22 .map(|character| {
23 if character.is_ascii_alphanumeric() || character == '_' {
24 character
25 } else {
26 '_'
27 }
28 })
29 .collect::<String>();
30 validate_identifier(self, &sanitized)?;
31 Ok(sanitized)
32 }
33
34 pub fn placeholder(self, index: usize) -> String {
35 match self {
36 Self::Postgres => format!("${index}"),
37 Self::MySql | Self::Sqlite => "?".to_owned(),
38 }
39 }
40
41 pub fn where_clause(
42 self,
43 table: &DbTable,
44 clauses: &[Where],
45 ) -> Result<SqlFragment, RustAuthError> {
46 self.where_clause_starting_at(table, clauses, 1)
47 }
48
49 pub fn where_clause_starting_at(
50 self,
51 table: &DbTable,
52 clauses: &[Where],
53 first_placeholder: usize,
54 ) -> Result<SqlFragment, RustAuthError> {
55 if clauses.is_empty() {
56 return Ok(SqlFragment::default());
57 }
58
59 let mut and_clauses = Vec::new();
60 let mut or_clauses = Vec::new();
61 for clause in clauses {
62 match clause.connector {
63 Connector::And => and_clauses.push(clause),
64 Connector::Or => or_clauses.push(clause),
65 }
66 }
67
68 let mut sql = String::from(" WHERE ");
69 let mut parts = Vec::new();
70 let mut params = Vec::new();
71
72 for clause in and_clauses {
73 parts.push(self.clause_sql(table, clause, &mut params, first_placeholder)?);
74 }
75
76 if !or_clauses.is_empty() {
77 let mut or_parts = Vec::new();
78 for clause in or_clauses {
79 or_parts.push(self.clause_sql(table, clause, &mut params, first_placeholder)?);
80 }
81 let or_sql = or_parts.join(" OR ");
82 if parts.is_empty() && or_parts.len() == 1 {
83 parts.push(or_sql);
84 } else {
85 parts.push(format!("({or_sql})"));
86 }
87 }
88
89 sql.push_str(&parts.join(" AND "));
90 Ok(SqlFragment { sql, params })
91 }
92
93 fn clause_sql(
94 self,
95 table: &DbTable,
96 clause: &Where,
97 params: &mut Vec<SqlParam>,
98 first_placeholder: usize,
99 ) -> Result<String, RustAuthError> {
100 let (_, field) = resolve_field(table, &clause.field)?;
101 let column = self.quote_identifier(&field.name)?;
102 if clause.value == DbValue::Null {
103 return Ok(match clause.operator {
104 WhereOperator::Eq => format!("{column} IS NULL"),
105 WhereOperator::Ne => format!("{column} IS NOT NULL"),
106 _ => {
107 return Err(RustAuthError::Adapter(
108 "null only supports Eq and Ne operators".to_owned(),
109 ))
110 }
111 });
112 }
113
114 match clause.operator {
115 WhereOperator::Eq
116 | WhereOperator::Ne
117 | WhereOperator::Lt
118 | WhereOperator::Lte
119 | WhereOperator::Gt
120 | WhereOperator::Gte => {
121 let operator = match clause.operator {
122 WhereOperator::Eq => "=",
123 WhereOperator::Ne => "!=",
124 WhereOperator::Lt => "<",
125 WhereOperator::Lte => "<=",
126 WhereOperator::Gt => ">",
127 WhereOperator::Gte => ">=",
128 _ => {
129 return Err(RustAuthError::Adapter(
130 "unsupported scalar where operator".to_owned(),
131 ));
132 }
133 };
134 let placeholder =
135 self.push_param(params, field, clause.value.clone(), first_placeholder);
136 if clause.mode == WhereMode::Insensitive
137 && field.field_type == DbFieldType::String
138 && matches!(&clause.value, DbValue::String(_))
139 && matches!(clause.operator, WhereOperator::Eq | WhereOperator::Ne)
140 {
141 Ok(format!("LOWER({column}) {operator} LOWER({placeholder})"))
142 } else {
143 Ok(format!("{column} {operator} {placeholder}"))
144 }
145 }
146 WhereOperator::In | WhereOperator::NotIn => {
147 let placeholders =
148 self.push_array_params(params, field, &clause.value, first_placeholder)?;
149 if placeholders.is_empty() {
150 return Ok(if clause.operator == WhereOperator::In {
151 "1 = 0".to_owned()
152 } else {
153 "1 = 1".to_owned()
154 });
155 }
156 let operator = if clause.operator == WhereOperator::In {
157 "IN"
158 } else {
159 "NOT IN"
160 };
161 let placeholders = if clause.mode == WhereMode::Insensitive
162 && field.field_type == DbFieldType::String
163 && matches!(&clause.value, DbValue::StringArray(_))
164 {
165 placeholders
166 .into_iter()
167 .map(|placeholder| format!("LOWER({placeholder})"))
168 .collect::<Vec<_>>()
169 } else {
170 placeholders
171 };
172 let column = if clause.mode == WhereMode::Insensitive
173 && field.field_type == DbFieldType::String
174 && matches!(&clause.value, DbValue::StringArray(_))
175 {
176 format!("LOWER({column})")
177 } else {
178 column
179 };
180 Ok(format!("{column} {operator} ({})", placeholders.join(", ")))
181 }
182 WhereOperator::Contains | WhereOperator::StartsWith | WhereOperator::EndsWith => {
183 let DbValue::String(value) = &clause.value else {
184 return Err(RustAuthError::Adapter(
185 "string pattern operators require string values".to_owned(),
186 ));
187 };
188 let value = escape_like_pattern(value);
189 let pattern = match clause.operator {
190 WhereOperator::Contains => format!("%{value}%"),
191 WhereOperator::StartsWith => format!("{value}%"),
192 WhereOperator::EndsWith => format!("%{value}"),
193 _ => {
194 return Err(RustAuthError::Adapter(
195 "unsupported string pattern where operator".to_owned(),
196 ));
197 }
198 };
199 let placeholder =
200 self.push_param(params, field, DbValue::String(pattern), first_placeholder);
201 if clause.mode == WhereMode::Insensitive {
202 if self == Self::Postgres {
203 Ok(format!(
204 "{column} ILIKE {placeholder} {}",
205 self.like_escape_clause()
206 ))
207 } else {
208 Ok(format!(
209 "LOWER({column}) LIKE LOWER({placeholder}) {}",
210 self.like_escape_clause()
211 ))
212 }
213 } else {
214 Ok(format!(
215 "{column} LIKE {placeholder} {}",
216 self.like_escape_clause()
217 ))
218 }
219 }
220 }
221 }
222
223 fn push_param(
224 &self,
225 params: &mut Vec<SqlParam>,
226 field: &DbField,
227 value: DbValue,
228 first_placeholder: usize,
229 ) -> String {
230 params.push(SqlParam::new(field, value));
231 self.placeholder(first_placeholder + params.len() - 1)
232 }
233
234 fn push_array_params(
235 self,
236 params: &mut Vec<SqlParam>,
237 field: &DbField,
238 value: &DbValue,
239 first_placeholder: usize,
240 ) -> Result<Vec<String>, RustAuthError> {
241 match value {
242 DbValue::StringArray(values) => Ok(values
243 .iter()
244 .map(|value| {
245 self.push_param(
246 params,
247 field,
248 DbValue::String(value.clone()),
249 first_placeholder,
250 )
251 })
252 .collect()),
253 DbValue::NumberArray(values) => Ok(values
254 .iter()
255 .map(|value| {
256 self.push_param(params, field, DbValue::Number(*value), first_placeholder)
257 })
258 .collect()),
259 _ => Err(RustAuthError::Adapter(
260 "IN and NOT IN require array values".to_owned(),
261 )),
262 }
263 }
264
265 pub fn order_limit_offset(
266 self,
267 table: &DbTable,
268 sort_by: Option<&Sort>,
269 limit: Option<usize>,
270 offset: Option<usize>,
271 ) -> Result<String, RustAuthError> {
272 let mut sql = String::new();
273 if let Some(sort) = sort_by {
274 let (_, field) = resolve_field(table, &sort.field)?;
275 let direction = match sort.direction {
276 SortDirection::Asc => "ASC",
277 SortDirection::Desc => "DESC",
278 };
279 sql.push_str(" ORDER BY ");
280 sql.push_str(&self.quote_identifier(&field.name)?);
281 sql.push(' ');
282 sql.push_str(direction);
283 }
284 if let Some(limit) = limit {
285 sql.push_str(" LIMIT ");
286 sql.push_str(&limit.to_string());
287 }
288 if let Some(offset) = offset {
289 sql.push_str(" OFFSET ");
290 sql.push_str(&offset.to_string());
291 }
292 Ok(sql)
293 }
294
295 pub fn column_definition(
296 self,
297 logical_name: &str,
298 field: &DbField,
299 ) -> Result<String, RustAuthError> {
300 let mut parts = vec![
301 self.quote_identifier(&field.name)?,
302 self.sql_type(logical_name, field),
303 ];
304 if logical_name == "id" || field.name == "id" {
305 match (self, field.generated_id) {
306 (Self::Postgres, Some(IdGeneration::Serial)) => {
307 parts.push("GENERATED BY DEFAULT AS IDENTITY".to_owned());
308 }
309 (Self::Postgres, Some(IdGeneration::Uuid)) => {
310 parts.push("DEFAULT pg_catalog.gen_random_uuid()".to_owned());
311 }
312 (Self::MySql, Some(IdGeneration::Serial)) => {
313 parts.push("AUTO_INCREMENT".to_owned());
314 }
315 _ => {}
316 }
317 parts.push("PRIMARY KEY".to_owned());
318 } else {
319 if field.required {
320 parts.push("NOT NULL".to_owned());
321 }
322 if field.unique {
323 parts.push("UNIQUE".to_owned());
324 }
325 }
326 if let Some(foreign_key) = &field.foreign_key {
327 parts.push(format!(
328 "REFERENCES {} ({})",
329 self.quote_identifier(&foreign_key.table)?,
330 self.quote_identifier(&foreign_key.field)?
331 ));
332 parts.push(on_delete_sql(foreign_key.on_delete).to_owned());
333 }
334 Ok(parts.join(" "))
335 }
336
337 pub fn create_table_statement(self, table: &DbTable) -> Result<String, RustAuthError> {
338 let columns = table
339 .fields
340 .iter()
341 .map(|(logical_name, field)| self.column_definition(logical_name, field))
342 .collect::<Result<Vec<_>, _>>()?;
343 let suffix = match self {
344 Self::MySql => " ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci",
345 Self::Postgres | Self::Sqlite => "",
346 };
347 Ok(format!(
348 "CREATE TABLE IF NOT EXISTS {} ({}){}",
349 self.quote_identifier(&table.name)?,
350 columns.join(", "),
351 suffix
352 ))
353 }
354
355 pub fn add_column_statement(
356 self,
357 table: &str,
358 logical_name: &str,
359 field: &DbField,
360 ) -> Result<String, RustAuthError> {
361 Ok(format!(
362 "ALTER TABLE {} ADD COLUMN {}",
363 self.quote_identifier(table)?,
364 self.column_definition(logical_name, field)?,
365 ))
366 }
367
368 pub fn create_index_statement(
369 self,
370 table: &str,
371 column: &str,
372 index: &str,
373 unique: bool,
374 ) -> Result<String, RustAuthError> {
375 let if_not_exists = match self {
376 Self::Postgres | Self::Sqlite => " IF NOT EXISTS",
377 Self::MySql => "",
378 };
379 let unique = if unique { "UNIQUE " } else { "" };
380 Ok(format!(
381 "CREATE {unique}INDEX{} {} ON {} ({})",
382 if_not_exists,
383 self.quote_identifier(index)?,
384 self.quote_identifier(table)?,
385 self.quote_identifier(column)?,
386 ))
387 }
388
389 pub fn sql_type(self, logical_name: &str, field: &DbField) -> String {
390 match self {
391 Self::Postgres => match field.field_type {
392 DbFieldType::String if field.generated_id == Some(IdGeneration::Uuid) => "UUID",
393 DbFieldType::String => "TEXT",
394 DbFieldType::Number => "BIGINT",
395 DbFieldType::Boolean => "BOOLEAN",
396 DbFieldType::Timestamp => "TIMESTAMPTZ",
397 DbFieldType::Json => "JSONB",
398 DbFieldType::StringArray => "TEXT[]",
399 DbFieldType::NumberArray => "BIGINT[]",
400 }
401 .to_owned(),
402 Self::Sqlite => match field.field_type {
403 DbFieldType::Number if field.generated_id == Some(IdGeneration::Serial) => {
404 "INTEGER"
405 }
406 DbFieldType::String
407 | DbFieldType::Timestamp
408 | DbFieldType::Json
409 | DbFieldType::StringArray
410 | DbFieldType::NumberArray => "TEXT",
411 DbFieldType::Number | DbFieldType::Boolean => "INTEGER",
412 }
413 .to_owned(),
414 Self::MySql => match field.field_type {
415 DbFieldType::Number if field.generated_id == Some(IdGeneration::Serial) => "BIGINT",
416 DbFieldType::String
417 if logical_name == "id"
418 || field.unique
419 || field.index
420 || field.foreign_key.is_some() =>
421 {
422 "VARCHAR(255)"
423 }
424 DbFieldType::String => "TEXT",
425 DbFieldType::Number => "BIGINT",
426 DbFieldType::Boolean => "BOOLEAN",
427 DbFieldType::Timestamp => "DATETIME(6)",
428 DbFieldType::Json | DbFieldType::StringArray | DbFieldType::NumberArray => "JSON",
429 }
430 .to_owned(),
431 }
432 }
433
434 pub fn type_matches(self, actual: &str, field: &DbField) -> bool {
435 let actual = normalized_type(actual);
436 match self {
437 Self::Postgres => match field.field_type {
438 DbFieldType::String => {
439 matches!(
440 actual.as_str(),
441 "text" | "character varying" | "varchar" | "uuid"
442 )
443 }
444 DbFieldType::Number => matches!(
445 actual.as_str(),
446 "bigint"
447 | "integer"
448 | "smallint"
449 | "numeric"
450 | "real"
451 | "double precision"
452 | "int8"
453 | "int4"
454 | "int2"
455 ),
456 DbFieldType::Boolean => matches!(actual.as_str(), "boolean" | "bool"),
457 DbFieldType::Timestamp => matches!(
458 actual.as_str(),
459 "timestamp with time zone"
460 | "timestamp without time zone"
461 | "timestamp"
462 | "timestamptz"
463 | "date"
464 ),
465 DbFieldType::Json => matches!(actual.as_str(), "jsonb" | "json"),
466 DbFieldType::StringArray => {
467 matches!(actual.as_str(), "text[]" | "_text" | "_varchar" | "_bpchar")
468 }
469 DbFieldType::NumberArray => matches!(
470 actual.as_str(),
471 "bigint[]" | "integer[]" | "_int8" | "_int4" | "_int2"
472 ),
473 },
474 Self::MySql => match field.field_type {
475 DbFieldType::String => matches!(actual.as_str(), "varchar" | "text" | "uuid"),
476 DbFieldType::Number => matches!(
477 actual.as_str(),
478 "integer" | "int" | "bigint" | "smallint" | "decimal" | "float" | "double"
479 ),
480 DbFieldType::Boolean => matches!(actual.as_str(), "boolean" | "tinyint"),
481 DbFieldType::Timestamp => {
482 matches!(actual.as_str(), "timestamp" | "datetime" | "date")
483 }
484 DbFieldType::Json | DbFieldType::StringArray | DbFieldType::NumberArray => {
485 actual.as_str() == "json"
486 }
487 },
488 Self::Sqlite => match field.field_type {
489 DbFieldType::String
490 | DbFieldType::Timestamp
491 | DbFieldType::Json
492 | DbFieldType::StringArray
493 | DbFieldType::NumberArray => matches!(
494 actual.as_str(),
495 "text" | "varchar" | "character varying" | "nvarchar" | "clob"
496 ),
497 DbFieldType::Number => matches!(
498 actual.as_str(),
499 "integer"
500 | "int"
501 | "bigint"
502 | "smallint"
503 | "tinyint"
504 | "numeric"
505 | "real"
506 | "double"
507 ),
508 DbFieldType::Boolean => matches!(
509 actual.as_str(),
510 "integer" | "int" | "bigint" | "smallint" | "tinyint" | "boolean" | "bool"
511 ),
512 },
513 }
514 }
515}
516
/// Escapes the LIKE wildcards `%` and `_`, and the escape character `\` itself, by prefixing
/// each with a backslash.
///
/// The result matches literally inside a LIKE pattern that uses a backslash `ESCAPE`
/// character.
fn escape_like_pattern(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut escaped, ch| {
            if let '%' | '_' | '\\' = ch {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
527
528fn validate_identifier(dialect: SqlDialect, identifier: &str) -> Result<(), RustAuthError> {
529 let mut chars = identifier.chars();
530 let Some(first) = chars.next() else {
531 return Err(RustAuthError::Adapter(format!(
532 "{} identifier cannot be empty",
533 dialect.name()
534 )));
535 };
536 if !(first.is_ascii_alphabetic() || first == '_') {
537 return Err(invalid_identifier(dialect, identifier));
538 }
539 if chars.any(|character| !(character.is_ascii_alphanumeric() || character == '_')) {
540 return Err(invalid_identifier(dialect, identifier));
541 }
542 Ok(())
543}
544
545fn invalid_identifier(dialect: SqlDialect, identifier: &str) -> RustAuthError {
546 RustAuthError::Adapter(format!(
547 "invalid {} identifier `{identifier}`",
548 dialect.name()
549 ))
550}
551
552impl SqlDialect {
553 fn name(self) -> &'static str {
554 match self {
555 Self::Postgres => "postgres",
556 Self::MySql => "mysql",
557 Self::Sqlite => "sqlite",
558 }
559 }
560
561 fn like_escape_clause(self) -> &'static str {
562 match self {
563 Self::MySql => "ESCAPE '\\\\'",
564 Self::Postgres | Self::Sqlite => "ESCAPE '\\'",
565 }
566 }
567}
568
/// Normalizes a column type name for comparison in `type_matches`.
///
/// Surrounding whitespace is trimmed, and everything from the first `(` onward is dropped
/// (length/precision arguments such as `varchar(255)`). The remainder is ASCII-lowercased.
fn normalized_type(value: &str) -> String {
    let trimmed = value.trim();
    let base = match trimmed.find('(') {
        Some(paren) => &trimmed[..paren],
        None => trimmed,
    };
    base.trim().to_ascii_lowercase()
}
578
579fn on_delete_sql(on_delete: OnDelete) -> &'static str {
580 match on_delete {
581 OnDelete::NoAction => "ON DELETE NO ACTION",
582 OnDelete::Restrict => "ON DELETE RESTRICT",
583 OnDelete::Cascade => "ON DELETE CASCADE",
584 OnDelete::SetNull => "ON DELETE SET NULL",
585 OnDelete::SetDefault => "ON DELETE SET DEFAULT",
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal two-column table: both `email` and `name` are plain string fields.
    fn user_table() -> DbTable {
        let mut fields = IndexMap::new();
        fields.insert(
            "email".to_owned(),
            DbField::new("email", DbFieldType::String),
        );
        fields.insert("name".to_owned(), DbField::new("name", DbFieldType::String));
        DbTable {
            name: "users".to_owned(),
            fields,
            order: None,
        }
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_eq() -> Result<(), RustAuthError> {
        let clause =
            Where::new("email", DbValue::String("ADA@EXAMPLE.COM".to_owned())).insensitive();

        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[clause])?;

        assert_eq!(fragment.sql, r#" WHERE LOWER("email") = LOWER($1)"#);
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_ne() -> Result<(), RustAuthError> {
        let clause = Where::new("email", DbValue::String("ADA@EXAMPLE.COM".to_owned()))
            .operator(WhereOperator::Ne)
            .insensitive();

        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[clause])?;

        assert_eq!(fragment.sql, r#" WHERE LOWER("email") != LOWER($1)"#);
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_in() -> Result<(), RustAuthError> {
        let clause = Where::new(
            "email",
            DbValue::StringArray(vec![
                "ADA@EXAMPLE.COM".to_owned(),
                "GRACE@EXAMPLE.COM".to_owned(),
            ]),
        )
        .operator(WhereOperator::In)
        .insensitive();

        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[clause])?;

        assert_eq!(
            fragment.sql,
            r#" WHERE LOWER("email") IN (LOWER($1), LOWER($2))"#
        );
        Ok(())
    }

    #[test]
    fn where_clause_applies_insensitive_mode_to_not_in() -> Result<(), RustAuthError> {
        let clause = Where::new(
            "email",
            DbValue::StringArray(vec!["ADA@EXAMPLE.COM".to_owned()]),
        )
        .operator(WhereOperator::NotIn)
        .insensitive();

        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[clause])?;

        assert_eq!(fragment.sql, r#" WHERE LOWER("email") NOT IN (LOWER($1))"#);
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_contains() -> Result<(), RustAuthError> {
        let clause = Where::new("email", DbValue::String(r"a%b_c\d".to_owned()))
            .operator(WhereOperator::Contains);

        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[clause])?;

        assert_eq!(fragment.sql, r#" WHERE "email" LIKE $1 ESCAPE '\'"#);
        assert_eq!(
            fragment.params[0].value,
            DbValue::String(r"%a\%b\_c\\d%".to_owned())
        );
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_starts_with() -> Result<(), RustAuthError> {
        let clause = Where::new("email", DbValue::String("100%_".to_owned()))
            .operator(WhereOperator::StartsWith);

        let fragment = SqlDialect::Sqlite.where_clause(&user_table(), &[clause])?;

        assert_eq!(fragment.sql, r#" WHERE "email" LIKE ? ESCAPE '\'"#);
        assert_eq!(
            fragment.params[0].value,
            DbValue::String(r"100\%\_%".to_owned())
        );
        Ok(())
    }

    #[test]
    fn where_clause_escapes_like_wildcards_for_insensitive_ends_with() -> Result<(), RustAuthError>
    {
        let clause = Where::new("email", DbValue::String(r"\_%".to_owned()))
            .operator(WhereOperator::EndsWith)
            .insensitive();

        let fragment = SqlDialect::MySql.where_clause(&user_table(), &[clause])?;

        assert_eq!(
            fragment.sql,
            " WHERE LOWER(`email`) LIKE LOWER(?) ESCAPE '\\\\'"
        );
        assert_eq!(
            fragment.params[0].value,
            DbValue::String(r"%\\\_\%".to_owned())
        );
        Ok(())
    }

    // No clauses: no WHERE keyword and nothing bound.
    #[test]
    fn where_clause_is_empty_without_clauses() -> Result<(), RustAuthError> {
        let fragment = SqlDialect::Postgres.where_clause(&user_table(), &[])?;

        assert_eq!(fragment.sql, "");
        assert!(fragment.params.is_empty());
        Ok(())
    }

    // NULL comparisons render as IS [NOT] NULL and bind no parameter.
    #[test]
    fn where_clause_renders_null_comparisons_without_params() -> Result<(), RustAuthError> {
        let is_null = SqlDialect::Postgres
            .where_clause(&user_table(), &[Where::new("email", DbValue::Null)])?;
        let is_not_null = SqlDialect::Postgres.where_clause(
            &user_table(),
            &[Where::new("email", DbValue::Null).operator(WhereOperator::Ne)],
        )?;

        assert_eq!(is_null.sql, r#" WHERE "email" IS NULL"#);
        assert!(is_null.params.is_empty());
        assert_eq!(is_not_null.sql, r#" WHERE "email" IS NOT NULL"#);
        assert!(is_not_null.params.is_empty());
        Ok(())
    }

    #[test]
    fn where_clause_rejects_ordering_operators_on_null() {
        let clause = Where::new("email", DbValue::Null).operator(WhereOperator::Lt);

        assert!(SqlDialect::Postgres
            .where_clause(&user_table(), &[clause])
            .is_err());
    }

    // Empty arrays must not produce an `IN ()` list.
    #[test]
    fn where_clause_renders_empty_in_lists_as_constants() -> Result<(), RustAuthError> {
        let empty_in =
            Where::new("email", DbValue::StringArray(Vec::new())).operator(WhereOperator::In);
        let empty_not_in =
            Where::new("email", DbValue::StringArray(Vec::new())).operator(WhereOperator::NotIn);

        let in_fragment = SqlDialect::MySql.where_clause(&user_table(), &[empty_in])?;
        let not_in_fragment = SqlDialect::MySql.where_clause(&user_table(), &[empty_not_in])?;

        assert_eq!(in_fragment.sql, " WHERE 1 = 0");
        assert!(in_fragment.params.is_empty());
        assert_eq!(not_in_fragment.sql, " WHERE 1 = 1");
        assert!(not_in_fragment.params.is_empty());
        Ok(())
    }

    // Postgres placeholders continue from the requested starting index.
    #[test]
    fn where_clause_starting_at_offsets_postgres_placeholders() -> Result<(), RustAuthError> {
        let clause = Where::new(
            "email",
            DbValue::StringArray(vec!["a@example.com".to_owned(), "b@example.com".to_owned()]),
        )
        .operator(WhereOperator::In);

        let fragment =
            SqlDialect::Postgres.where_clause_starting_at(&user_table(), &[clause], 3)?;

        assert_eq!(fragment.sql, r#" WHERE "email" IN ($3, $4)"#);
        assert_eq!(fragment.params.len(), 2);
        Ok(())
    }

    #[test]
    fn quote_identifier_quotes_each_dotted_part() -> Result<(), RustAuthError> {
        assert_eq!(
            SqlDialect::Postgres.quote_identifier("auth.users")?,
            "\"auth\".\"users\""
        );
        assert_eq!(
            SqlDialect::MySql.quote_identifier("auth.users")?,
            "`auth`.`users`"
        );
        Ok(())
    }

    // Anything that could break out of the quotes, or that is empty, must be rejected.
    #[test]
    fn quote_identifier_rejects_unsafe_identifiers() {
        for identifier in ["", "users;drop", "1users", "users..id", "us\"ers", "us`ers"] {
            assert!(
                SqlDialect::Sqlite.quote_identifier(identifier).is_err(),
                "{identifier}"
            );
        }
    }

    #[test]
    fn sanitize_identifier_replaces_unsupported_characters() -> Result<(), RustAuthError> {
        assert_eq!(
            SqlDialect::Postgres.sanitize_identifier("user-name.v2")?,
            "user_name_v2"
        );
        Ok(())
    }

    #[test]
    fn order_limit_offset_renders_limit_before_offset() -> Result<(), RustAuthError> {
        let sql = SqlDialect::Sqlite.order_limit_offset(&user_table(), None, Some(10), Some(5))?;

        assert_eq!(sql, " LIMIT 10 OFFSET 5");
        Ok(())
    }
}