1use std::collections::HashMap;
4use std::hash::{Hash, Hasher};
5use std::sync::{Arc, LazyLock, Mutex as StdMutex};
6
7use time::OffsetDateTime;
8
9use crate::context::AuthContext;
10use crate::crypto::random::generate_random_string;
11use crate::db::{
12 auth_schema, AuthSchemaOptions, Create, DbAdapter, DbRecord, DbSchema, DbValue, Delete,
13 DeleteMany, FindMany, SchemaTable, Sort, SortDirection, TransactionAdapter, Update,
14 Verification, Where, WhereOperator,
15};
16use crate::error::RustAuthError;
17use crate::options::{SecondaryStorage, StoreIdentifierOption, VerificationOptions};
18use sha2::{Digest, Sha256};
19use tokio::sync::Mutex;
20
/// Model (table) name under which verification records are stored.
const VERIFICATION_MODEL: &str = "verification";
/// Length of the random id generated when a caller does not supply one.
const DEFAULT_ID_LENGTH: usize = 32;
/// Columns selected whenever a verification record is written or read back.
const VERIFICATION_FIELDS: [&str; 6] =
    ["id", "identifier", "value", "expires_at", "created_at", "updated_at"];
31
32fn default_auth_schema() -> &'static DbSchema {
33 static SCHEMA: LazyLock<DbSchema> = LazyLock::new(|| auth_schema(AuthSchemaOptions::default()));
34 &SCHEMA
35}
36
37fn database_verification_store<'a>(
38 context: &'a AuthContext,
39 options: &VerificationOptions,
40) -> Result<DbVerificationStore<'a>, RustAuthError> {
41 if context.secondary_storage().is_some()
42 && context.db_schema.table(VERIFICATION_MODEL).is_none()
43 {
44 return Ok(DbVerificationStore::with_default_schema(
45 context.adapter_ref()?,
46 options.clone(),
47 ));
48 }
49 DbVerificationStore::from_context(context)
50}
51
52pub async fn process_verification_identifier(
54 options: &VerificationOptions,
55 identifier: &str,
56) -> Result<String, RustAuthError> {
57 match options.store_identifier.resolve(identifier) {
58 StoreIdentifierOption::Plain => Ok(identifier.to_owned()),
59 StoreIdentifierOption::Hashed => Ok(hash_verification_identifier(identifier)),
60 StoreIdentifierOption::Custom(hash_fn) => hash_fn(identifier.to_owned()).await,
61 }
62}
63
64fn hash_verification_identifier(identifier: &str) -> String {
65 hex::encode(Sha256::digest(identifier.as_bytes()))
66}
67
68#[derive(Debug, Clone, PartialEq, Eq)]
69pub struct CreateVerificationInput {
70 pub id: Option<String>,
71 pub identifier: String,
72 pub value: String,
73 pub expires_at: OffsetDateTime,
74}
75
76impl CreateVerificationInput {
77 pub fn new(
78 identifier: impl Into<String>,
79 value: impl Into<String>,
80 expires_at: OffsetDateTime,
81 ) -> Self {
82 Self {
83 id: None,
84 identifier: identifier.into(),
85 value: value.into(),
86 expires_at,
87 }
88 }
89
90 #[must_use]
91 pub fn id(mut self, id: impl Into<String>) -> Self {
92 self.id = Some(id.into());
93 self
94 }
95}
96
97#[derive(Debug, Clone, Default, PartialEq, Eq)]
98pub struct UpdateVerificationInput {
99 pub value: Option<String>,
100 pub expires_at: Option<OffsetDateTime>,
101}
102
103impl UpdateVerificationInput {
104 pub fn new() -> Self {
105 Self::default()
106 }
107
108 #[must_use]
109 pub fn value(mut self, value: impl Into<String>) -> Self {
110 self.value = Some(value.into());
111 self
112 }
113
114 #[must_use]
115 pub fn expires_at(mut self, expires_at: OffsetDateTime) -> Self {
116 self.expires_at = Some(expires_at);
117 self
118 }
119}
120
121#[derive(Clone)]
122pub struct DbVerificationStore<'a> {
123 adapter: &'a dyn DbAdapter,
124 schema: DbSchema,
125 options: VerificationOptions,
126}
127
128impl<'a> DbVerificationStore<'a> {
129 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
130 Self::with_options(
131 adapter,
132 default_auth_schema().clone(),
133 VerificationOptions::default(),
134 )
135 }
136
137 pub fn from_context(context: &'a AuthContext) -> Result<Self, RustAuthError> {
138 Ok(Self::with_options(
139 context.adapter_ref()?,
140 context.db_schema.clone(),
141 context.options.verification.clone(),
142 ))
143 }
144
145 pub fn with_options(
146 adapter: &'a dyn DbAdapter,
147 schema: DbSchema,
148 options: VerificationOptions,
149 ) -> Self {
150 Self {
151 adapter,
152 schema,
153 options,
154 }
155 }
156
157 pub fn with_default_schema(adapter: &'a dyn DbAdapter, options: VerificationOptions) -> Self {
158 Self::with_options(adapter, default_auth_schema().clone(), options)
159 }
160
161 pub(super) fn adapter(&self) -> &dyn DbAdapter {
162 self.adapter
163 }
164
165 fn verifications(&self) -> Result<SchemaTable<'_>, RustAuthError> {
166 SchemaTable::new(&self.schema, VERIFICATION_MODEL)
167 }
168
169 fn parse_verification(&self, record: DbRecord) -> Result<Verification, RustAuthError> {
170 verification_from_record(self.verifications()?.map_record(record)?)
171 }
172
173 pub async fn create_verification(
174 &self,
175 input: CreateVerificationInput,
176 ) -> Result<Verification, RustAuthError> {
177 let stored_identifier =
178 process_verification_identifier(&self.options, &input.identifier).await?;
179 let now = OffsetDateTime::now_utc();
180 let id = input
181 .id
182 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
183
184 let record = self
185 .adapter
186 .create(
187 Create::new(VERIFICATION_MODEL)
188 .data("id", DbValue::String(id))
189 .data("identifier", DbValue::String(stored_identifier))
190 .data("value", DbValue::String(input.value))
191 .data("expires_at", DbValue::Timestamp(input.expires_at))
192 .data("created_at", DbValue::Timestamp(now))
193 .data("updated_at", DbValue::Timestamp(now))
194 .select(VERIFICATION_FIELDS)
195 .force_allow_id(),
196 )
197 .await?;
198
199 self.parse_verification(record)
200 }
201
202 pub async fn find_verification(
203 &self,
204 identifier: &str,
205 ) -> Result<Option<Verification>, RustAuthError> {
206 if !self.options.disable_cleanup {
207 self.delete_expired_verifications().await?;
208 }
209
210 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
211 let Some(record) = self
212 .adapter
213 .find_many(
214 FindMany::new(VERIFICATION_MODEL)
215 .where_clause(identifier_where(&stored_identifier))
216 .sort_by(Sort::new("created_at", SortDirection::Desc))
217 .limit(1)
218 .select(VERIFICATION_FIELDS),
219 )
220 .await?
221 .into_iter()
222 .next()
223 else {
224 return Ok(None);
225 };
226
227 let verification = self.parse_verification(record)?;
228 if verification.expires_at <= OffsetDateTime::now_utc() {
229 if !self.options.disable_cleanup {
230 self.delete_expired_verifications().await?;
231 }
232 return Ok(None);
233 }
234
235 Ok(Some(verification))
236 }
237
238 pub async fn find_verification_including_expired(
239 &self,
240 identifier: &str,
241 ) -> Result<Option<Verification>, RustAuthError> {
242 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
243 self.adapter
244 .find_many(
245 FindMany::new(VERIFICATION_MODEL)
246 .where_clause(identifier_where(&stored_identifier))
247 .sort_by(Sort::new("created_at", SortDirection::Desc))
248 .limit(1)
249 .select(VERIFICATION_FIELDS),
250 )
251 .await?
252 .into_iter()
253 .next()
254 .map(|record| self.parse_verification(record))
255 .transpose()
256 }
257
258 pub async fn consume_verification_including_expired(
264 &self,
265 identifier: &str,
266 ) -> Result<Option<Verification>, RustAuthError> {
267 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
268 let Some(record) = self
269 .adapter
270 .find_many(
271 FindMany::new(VERIFICATION_MODEL)
272 .where_clause(identifier_where(&stored_identifier))
273 .sort_by(Sort::new("created_at", SortDirection::Desc))
274 .limit(1)
275 .select(VERIFICATION_FIELDS),
276 )
277 .await?
278 .into_iter()
279 .next()
280 else {
281 return Ok(None);
282 };
283 let verification = self.parse_verification(record)?;
284 let deleted = self
285 .adapter
286 .delete_many(
287 DeleteMany::new(VERIFICATION_MODEL)
288 .where_clause(identifier_where(&stored_identifier))
289 .where_clause(Where::new("id", DbValue::String(verification.id.clone()))),
290 )
291 .await?;
292 if deleted == 0 {
293 return Ok(None);
294 }
295 Ok(Some(verification))
296 }
297
298 pub async fn compare_and_update_verification_value(
303 &self,
304 identifier: &str,
305 verification_id: &str,
306 expected_value: &str,
307 new_value: String,
308 ) -> Result<Option<Verification>, RustAuthError> {
309 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
310 self.adapter
311 .update(
312 Update::new(VERIFICATION_MODEL)
313 .where_clause(identifier_where(&stored_identifier))
314 .where_clause(Where::new(
315 "id",
316 DbValue::String(verification_id.to_owned()),
317 ))
318 .where_clause(Where::new(
319 "value",
320 DbValue::String(expected_value.to_owned()),
321 ))
322 .data("value", DbValue::String(new_value))
323 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
324 )
325 .await?
326 .map(|record| self.parse_verification(record))
327 .transpose()
328 }
329
330 pub async fn update_verification(
331 &self,
332 identifier: &str,
333 input: UpdateVerificationInput,
334 ) -> Result<Option<Verification>, RustAuthError> {
335 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
336 let mut query =
337 Update::new(VERIFICATION_MODEL).where_clause(identifier_where(&stored_identifier));
338
339 if let Some(value) = input.value {
340 query = query.data("value", DbValue::String(value));
341 }
342 if let Some(expires_at) = input.expires_at {
343 query = query.data("expires_at", DbValue::Timestamp(expires_at));
344 }
345 query = query.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
346
347 self.adapter
348 .update(query)
349 .await?
350 .map(|record| self.parse_verification(record))
351 .transpose()
352 }
353
354 pub async fn delete_verification(&self, identifier: &str) -> Result<(), RustAuthError> {
355 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
356 self.adapter
357 .delete(
358 Delete::new(VERIFICATION_MODEL).where_clause(identifier_where(&stored_identifier)),
359 )
360 .await
361 }
362
363 pub async fn take_verification(
368 &self,
369 identifier: &str,
370 ) -> Result<Option<Verification>, RustAuthError> {
371 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
372 if self.adapter.capabilities().supports_transactions {
373 let options = self.options.clone();
374 let schema = self.schema.clone();
375 let identifier = identifier.to_owned();
376 let taken = Arc::new(Mutex::new(None));
377 let taken_capture = Arc::clone(&taken);
378 self.adapter
379 .transaction(Box::new(move |transaction: TransactionAdapter<'_>| {
380 let taken = Arc::clone(&taken_capture);
381 let options = options.clone();
382 let schema = schema.clone();
383 let identifier = identifier.clone();
384 Box::pin(async move {
385 let store = DbVerificationStore::with_options(
386 transaction.as_ref(),
387 schema,
388 options,
389 );
390 if let Some(verification) = store.find_verification(&identifier).await? {
391 if verification.expires_at > OffsetDateTime::now_utc() {
392 store.delete_verification(&identifier).await?;
393 *taken.lock().await = Some(verification);
394 }
395 }
396 Ok(())
397 })
398 }))
399 .await?;
400 return Ok(taken.lock().await.take());
401 }
402
403 let take_lock = verification_take_lock(self.adapter, &stored_identifier)?;
404 let _guard = take_lock.lock().await;
405 let Some(verification) = self.find_verification(identifier).await? else {
406 return Ok(None);
407 };
408 self.delete_verification(identifier).await?;
409 Ok(Some(verification))
410 }
411
412 pub async fn take_verification_including_expired(
413 &self,
414 identifier: &str,
415 ) -> Result<Option<Verification>, RustAuthError> {
416 self.consume_verification_including_expired(identifier)
417 .await
418 }
419
420 pub async fn delete_expired_verifications(&self) -> Result<u64, RustAuthError> {
421 if self.options.disable_cleanup {
422 return Ok(0);
423 }
424 self.adapter
425 .delete_many(
426 DeleteMany::new(VERIFICATION_MODEL).where_clause(
427 Where::new("expires_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
428 .operator(WhereOperator::Lt),
429 ),
430 )
431 .await
432 }
433}
434
435#[derive(Clone)]
437pub struct VerificationStore<'a> {
438 database: DbVerificationStore<'a>,
439 secondary_storage: Option<Arc<dyn SecondaryStorage>>,
440 options: VerificationOptions,
441}
442
443impl<'a> VerificationStore<'a> {
444 pub fn new(context: &'a AuthContext) -> Result<Self, RustAuthError> {
445 let options = context.options.verification.clone();
446 Ok(Self {
447 database: database_verification_store(context, &options)?,
448 secondary_storage: context.secondary_storage(),
449 options,
450 })
451 }
452
453 pub async fn create_verification(
454 &self,
455 input: CreateVerificationInput,
456 ) -> Result<Verification, RustAuthError> {
457 let Some(storage) = &self.secondary_storage else {
458 return self.database.create_verification(input).await;
459 };
460 let stored_identifier =
461 process_verification_identifier(&self.options, &input.identifier).await?;
462 let now = OffsetDateTime::now_utc();
463 let verification = Verification {
464 id: input
465 .id
466 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH)),
467 identifier: stored_identifier,
468 value: input.value,
469 expires_at: input.expires_at,
470 created_at: now,
471 updated_at: now,
472 };
473 storage
474 .set(
475 &verification_key(&verification.identifier),
476 serialize_verification(&verification)?,
477 ttl_seconds(verification.expires_at),
478 )
479 .await?;
480 Ok(verification)
481 }
482
483 pub async fn find_verification(
484 &self,
485 identifier: &str,
486 ) -> Result<Option<Verification>, RustAuthError> {
487 let Some(storage) = &self.secondary_storage else {
488 return self.database.find_verification(identifier).await;
489 };
490 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
491 let Some(verification) = self
492 .find_secondary_verification(storage.as_ref(), &stored_identifier)
493 .await?
494 else {
495 return Ok(None);
496 };
497 if verification.expires_at <= OffsetDateTime::now_utc() {
498 storage
499 .delete(&verification_key(&stored_identifier))
500 .await?;
501 return Ok(None);
502 }
503 Ok(Some(verification))
504 }
505
506 pub async fn find_verification_including_expired(
507 &self,
508 identifier: &str,
509 ) -> Result<Option<Verification>, RustAuthError> {
510 let Some(storage) = &self.secondary_storage else {
511 return self
512 .database
513 .find_verification_including_expired(identifier)
514 .await;
515 };
516 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
517 self.find_secondary_verification(storage.as_ref(), &stored_identifier)
518 .await
519 }
520
521 pub async fn update_verification(
522 &self,
523 identifier: &str,
524 input: UpdateVerificationInput,
525 ) -> Result<Option<Verification>, RustAuthError> {
526 let Some(storage) = &self.secondary_storage else {
527 return self.database.update_verification(identifier, input).await;
528 };
529 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
530 let Some(mut verification) = self
531 .find_secondary_verification(storage.as_ref(), &stored_identifier)
532 .await?
533 else {
534 return Ok(None);
535 };
536 if let Some(value) = input.value {
537 verification.value = value;
538 }
539 if let Some(expires_at) = input.expires_at {
540 verification.expires_at = expires_at;
541 }
542 verification.updated_at = OffsetDateTime::now_utc();
543 storage
544 .set(
545 &verification_key(&stored_identifier),
546 serialize_verification(&verification)?,
547 ttl_seconds(verification.expires_at),
548 )
549 .await?;
550 Ok(Some(verification))
551 }
552
553 pub async fn delete_verification(&self, identifier: &str) -> Result<(), RustAuthError> {
554 let Some(storage) = &self.secondary_storage else {
555 return self.database.delete_verification(identifier).await;
556 };
557 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
558 storage.delete(&verification_key(&stored_identifier)).await
559 }
560
561 pub async fn take_verification(
562 &self,
563 identifier: &str,
564 ) -> Result<Option<Verification>, RustAuthError> {
565 let Some(storage) = &self.secondary_storage else {
566 return self.database.take_verification(identifier).await;
567 };
568 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
569 let Some(raw) = storage.take(&verification_key(&stored_identifier)).await? else {
570 return Ok(None);
571 };
572 let verification = deserialize_verification(&raw)?;
573 if verification.expires_at <= OffsetDateTime::now_utc() {
574 return Ok(None);
575 }
576 Ok(Some(verification))
577 }
578
579 pub async fn take_verification_including_expired(
580 &self,
581 identifier: &str,
582 ) -> Result<Option<Verification>, RustAuthError> {
583 self.consume_verification_including_expired(identifier)
584 .await
585 }
586
587 pub async fn consume_verification_including_expired(
593 &self,
594 identifier: &str,
595 ) -> Result<Option<Verification>, RustAuthError> {
596 let Some(storage) = &self.secondary_storage else {
597 return self
598 .database
599 .consume_verification_including_expired(identifier)
600 .await;
601 };
602 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
603 let Some(raw) = storage.take(&verification_key(&stored_identifier)).await? else {
604 return Ok(None);
605 };
606 deserialize_verification(&raw).map(Some)
607 }
608
609 pub async fn compare_and_update_verification_value(
611 &self,
612 identifier: &str,
613 verification_id: &str,
614 expected_value: &str,
615 new_value: String,
616 ) -> Result<Option<Verification>, RustAuthError> {
617 let Some(storage) = &self.secondary_storage else {
618 return self
619 .database
620 .compare_and_update_verification_value(
621 identifier,
622 verification_id,
623 expected_value,
624 new_value,
625 )
626 .await;
627 };
628 let stored_identifier = process_verification_identifier(&self.options, identifier).await?;
629 let take_lock = verification_take_lock(self.database.adapter(), &stored_identifier)?;
630 let _guard = take_lock.lock().await;
631 let key = verification_key(&stored_identifier);
632 let Some(raw) = storage.get(&key).await? else {
633 return Ok(None);
634 };
635 let verification = deserialize_verification(&raw)?;
636 if verification.id != verification_id || verification.value != expected_value {
637 return Ok(None);
638 }
639 let mut updated = verification;
640 updated.value = new_value;
641 updated.updated_at = OffsetDateTime::now_utc();
642 storage
643 .set(
644 &key,
645 serialize_verification(&updated)?,
646 ttl_seconds(updated.expires_at),
647 )
648 .await?;
649 Ok(Some(updated))
650 }
651
652 pub async fn delete_expired_verifications(&self) -> Result<u64, RustAuthError> {
653 if self.options.disable_cleanup {
654 return Ok(0);
655 }
656 let Some(_storage) = &self.secondary_storage else {
657 return self.database.delete_expired_verifications().await;
658 };
659 Ok(0)
660 }
661
662 async fn find_secondary_verification(
663 &self,
664 storage: &dyn SecondaryStorage,
665 stored_identifier: &str,
666 ) -> Result<Option<Verification>, RustAuthError> {
667 storage
668 .get(&verification_key(stored_identifier))
669 .await?
670 .map(|value| deserialize_verification(&value))
671 .transpose()
672 }
673}
674
675static VERIFICATION_TAKE_LOCKS: LazyLock<StdMutex<HashMap<u64, Arc<Mutex<()>>>>> =
676 LazyLock::new(|| StdMutex::new(HashMap::new()));
677
678fn verification_take_lock(
679 adapter: &dyn DbAdapter,
680 stored_identifier: &str,
681) -> Result<Arc<Mutex<()>>, RustAuthError> {
682 let mut hasher = std::collections::hash_map::DefaultHasher::new();
683 (adapter as *const dyn DbAdapter).hash(&mut hasher);
684 stored_identifier.hash(&mut hasher);
685 let key = hasher.finish();
686 let mut table = VERIFICATION_TAKE_LOCKS
687 .lock()
688 .map_err(|_| RustAuthError::LockPoisoned {
689 context: "verification take lock table",
690 })?;
691 Ok(table
692 .entry(key)
693 .or_insert_with(|| Arc::new(Mutex::new(())))
694 .clone())
695}
696
697fn identifier_where(identifier: &str) -> Where {
698 Where::new("identifier", DbValue::String(identifier.to_owned()))
699}
700
701fn verification_from_record(record: DbRecord) -> Result<Verification, RustAuthError> {
702 Ok(Verification {
703 id: required_string(&record, "id")?.to_owned(),
704 identifier: required_string(&record, "identifier")?.to_owned(),
705 value: required_string(&record, "value")?.to_owned(),
706 expires_at: required_timestamp(&record, "expires_at")?,
707 created_at: required_timestamp(&record, "created_at")?,
708 updated_at: required_timestamp(&record, "updated_at")?,
709 })
710}
711
/// Secondary-storage key for the verification stored under `identifier`
/// (the already-processed identifier), i.e. `verification:<identifier>`.
fn verification_key(identifier: &str) -> String {
    const PREFIX: &str = "verification:";
    let mut key = String::with_capacity(PREFIX.len() + identifier.len());
    key.push_str(PREFIX);
    key.push_str(identifier);
    key
}
715
716fn serialize_verification(verification: &Verification) -> Result<String, RustAuthError> {
717 serde_json::to_string(verification).map_err(|error| RustAuthError::Serialization {
718 context: "serializing verification record",
719 message: error.to_string(),
720 })
721}
722
723fn deserialize_verification(value: &str) -> Result<Verification, RustAuthError> {
724 serde_json::from_str(value).map_err(|error| RustAuthError::Serialization {
725 context: "deserializing verification record",
726 message: error.to_string(),
727 })
728}
729
730fn ttl_seconds(expires_at: OffsetDateTime) -> Option<u64> {
731 let seconds = (expires_at - OffsetDateTime::now_utc()).whole_seconds();
732 Some(u64::try_from(seconds.max(0)).unwrap_or(0))
733}
734
735fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
736 match record.get(field) {
737 Some(DbValue::String(value)) => Ok(value),
738 Some(_) => Err(invalid_field(field, "string")),
739 None => Err(missing_field(field)),
740 }
741}
742
743fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, RustAuthError> {
744 match record.get(field) {
745 Some(DbValue::Timestamp(value)) => Ok(*value),
746 Some(_) => Err(invalid_field(field, "timestamp")),
747 None => Err(missing_field(field)),
748 }
749}
750
751fn missing_field(field: &str) -> RustAuthError {
752 RustAuthError::MissingRecordField {
753 record: "verification",
754 field: field.to_owned(),
755 }
756}
757
758fn invalid_field(field: &str, expected: &'static str) -> RustAuthError {
759 RustAuthError::InvalidRecordField {
760 record: "verification",
761 field: field.to_owned(),
762 expected,
763 }
764}