Skip to main content

rustrails_record/locking/
pessimistic.rs

1use std::{future::Future, pin::Pin};
2
3use sea_orm::{
4    ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection, EntityTrait, Iterable,
5    QueryFilter, QuerySelect,
6    sea_query::{LockBehavior, LockType},
7};
8
9use crate::{
10    base::{Record, RecordError, RecordState},
11    querying::AsyncQuerying,
12    relation::resolve_column,
13};
14
15/// Boxed future returned by [`PessimisticLocking::with_lock`] callbacks.
16pub type LockFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, RecordError>> + Send + 'a>>;
17
/// Row-lock flavour that [`PessimisticLocking`] asks the database for.
///
/// [`LockOption::ForUpdate`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LockOption {
    /// Plain `FOR UPDATE`, or whatever the backend offers that comes closest.
    #[default]
    ForUpdate,
    /// `FOR UPDATE NOWAIT`; rejected on SQLite by this module.
    Nowait,
    /// `FOR UPDATE SKIP LOCKED`; on SQLite this degrades to a plain read.
    SkipLocked,
}
29
30/// Pessimistic locking support backed by row locks when the database supports them.
31///
32/// SQLite does not support `SELECT ... FOR UPDATE`, so `lock` and `lock_bang` degrade to
33/// a plain reload. `with_lock` uses `BEGIN EXCLUSIVE` when possible to emulate a write lock.
34#[allow(dead_code)]
35pub(crate) trait PessimisticLocking: Record {
36    /// Loads a record by primary key while requesting an exclusive lock when possible.
37    #[allow(private_bounds)]
38    async fn lock(id: i64, db: &DatabaseConnection) -> Result<Self, RecordError>
39    where
40        Self: Sized + AsyncQuerying,
41        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
42    {
43        Self::lock_with_option(id, LockOption::ForUpdate, db).await
44    }
45
46    /// Loads a record by primary key while applying the requested lock option.
47    #[allow(private_bounds)]
48    async fn lock_with_option(
49        id: i64,
50        option: LockOption,
51        db: &DatabaseConnection,
52    ) -> Result<Self, RecordError>
53    where
54        Self: Sized + AsyncQuerying,
55        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
56    {
57        if matches!(db.get_database_backend(), DatabaseBackend::Sqlite)
58            && matches!(option, LockOption::Nowait)
59        {
60            return Err(RecordError::Invalid(
61                "SQLite does not support NOWAIT row locks".to_owned(),
62            ));
63        }
64
65        let primary_key = resolve_column::<Self>(Self::primary_key_name())?;
66        let query = match option {
67            LockOption::ForUpdate => Self::Entity::find()
68                .filter(primary_key.eq(id))
69                .lock_exclusive(),
70            LockOption::Nowait => Self::Entity::find()
71                .filter(primary_key.eq(id))
72                .lock_with_behavior(LockType::Update, LockBehavior::Nowait),
73            LockOption::SkipLocked => Self::Entity::find()
74                .filter(primary_key.eq(id))
75                .lock_with_behavior(LockType::Update, LockBehavior::SkipLocked),
76        };
77
78        let model = query.one(db).await?.ok_or(RecordError::NotFound)?;
79        let mut record = Self::from_sea_model(model);
80        record.set_record_state(RecordState::Persisted);
81        Ok(record)
82    }
83
84    /// Reloads the record while requesting an exclusive lock when possible.
85    async fn lock_bang(&mut self, db: &DatabaseConnection) -> Result<(), RecordError>
86    where
87        Self: Sized + AsyncQuerying,
88        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
89    {
90        self.lock_bang_with_option(LockOption::ForUpdate, db).await
91    }
92
93    /// Reloads the record while applying the requested lock option.
94    async fn lock_bang_with_option(
95        &mut self,
96        option: LockOption,
97        db: &DatabaseConnection,
98    ) -> Result<(), RecordError>
99    where
100        Self: Sized + AsyncQuerying,
101        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
102    {
103        let Some(id) = self.id() else {
104            return Ok(());
105        };
106
107        *self = Self::lock_with_option(id, option, db).await?;
108        Ok(())
109    }
110
111    /// Runs the provided closure inside a transaction after reloading the record with a lock.
112    async fn with_lock<F, T>(
113        &mut self,
114        db: &DatabaseConnection,
115        option: LockOption,
116        f: F,
117    ) -> Result<T, RecordError>
118    where
119        Self: Sized + AsyncQuerying,
120        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
121        F: for<'a> FnOnce(&'a mut Self, &'a DatabaseConnection) -> LockFuture<'a, T> + Send,
122        T: Send,
123    {
124        let started_transaction = begin_lock_scope(db, option).await?;
125        let result = async {
126            self.lock_bang_with_option(option, db).await?;
127            f(self, db).await
128        }
129        .await;
130
131        if !started_transaction {
132            return result;
133        }
134
135        match result {
136            Ok(value) => {
137                db.execute_unprepared("COMMIT").await?;
138                Ok(value)
139            }
140            Err(error) => {
141                db.execute_unprepared("ROLLBACK").await?;
142                Err(error)
143            }
144        }
145    }
146}
147
148#[allow(dead_code)]
149async fn begin_lock_scope(
150    db: &DatabaseConnection,
151    option: LockOption,
152) -> Result<bool, RecordError> {
153    let begin_sql = match (db.get_database_backend(), option) {
154        (DatabaseBackend::Sqlite, LockOption::ForUpdate) => Some("BEGIN EXCLUSIVE"),
155        (DatabaseBackend::Sqlite, LockOption::SkipLocked) => Some("BEGIN"),
156        (DatabaseBackend::Sqlite, LockOption::Nowait) => {
157            return Err(RecordError::Invalid(
158                "SQLite does not support NOWAIT row locks".to_owned(),
159            ));
160        }
161        (_, _) => Some("BEGIN"),
162    };
163
164    let Some(begin_sql) = begin_sql else {
165        return Ok(false);
166    };
167
168    match db.execute_unprepared(begin_sql).await {
169        Ok(_) => Ok(true),
170        Err(error) if transaction_already_open(&error) => Ok(false),
171        Err(error) => Err(error.into()),
172    }
173}
174
175#[allow(dead_code)]
176fn transaction_already_open(error: &sea_orm::DbErr) -> bool {
177    let message = error.to_string().to_ascii_lowercase();
178    message.contains("within a transaction")
179        || message.contains("already a transaction")
180        || message.contains("transaction within a transaction")
181        || message.contains("cannot start a transaction")
182        || message.contains("transaction already in progress")
183}
184
#[cfg(test)]
mod tests {
    //! Behavioural tests for `PessimisticLocking` against the shared test fixtures.
    //!
    //! Rows inserted by `seed_users`:
    //! - id 1: "Alice", alice@example.com
    //! - id 2: "Bob", bob@example.com
    //! - id 3: "Carol"

    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use serde_json::json;

    use super::{LockOption, PessimisticLocking};
    use crate::base::test_support::{TestUser, seed_users, setup_db};
    use crate::persistence::AsyncPersistence;
    use crate::querying::AsyncQuerying;
    use crate::transactions::transaction;
    use crate::{Record, RecordError};

    impl PessimisticLocking for TestUser {}

    /// `lock` loads the row whose primary key matches.
    #[tokio::test]
    async fn lock_returns_matching_record() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let bob = TestUser::lock(2, &conn).await.expect("row should load");
        assert_eq!(bob.name, "Bob");
    }

    /// An unknown primary key surfaces as `NotFound`.
    #[tokio::test]
    async fn lock_returns_not_found_for_missing_row() {
        let conn = setup_db().await;

        let outcome = TestUser::lock(404, &conn).await;
        let error = outcome.expect_err("missing row should fail");
        assert!(matches!(error, RecordError::NotFound));
    }

    /// Records obtained through `lock` are flagged as persisted.
    #[tokio::test]
    async fn lock_marks_record_as_persisted() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let alice = TestUser::lock(1, &conn).await.expect("row should load");
        assert!(alice.persisted());
    }

    /// Locking the same row twice in a row succeeds and yields the same data.
    #[tokio::test]
    async fn lock_can_be_called_repeatedly() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let mut names = Vec::with_capacity(2);
        for message in ["first lock should work", "second lock should work"] {
            let user = TestUser::lock(1, &conn).await.expect(message);
            names.push(user.name);
        }
        assert_eq!(names[0], names[1]);
    }

    /// Locking is read-only with respect to the table's contents.
    #[tokio::test]
    async fn lock_does_not_change_row_count() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let _locked = TestUser::lock(3, &conn).await.expect("row should load");
        let total = TestUser::count(&conn).await.expect("count should succeed");
        assert_eq!(total, 3);
    }

    /// The id and non-name columns come back intact.
    #[tokio::test]
    async fn lock_preserves_identifier_and_email() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let bob = TestUser::lock(2, &conn).await.expect("row should load");
        assert_eq!(bob.id(), Some(2));
        assert_eq!(bob.email, "bob@example.com");
    }

    /// SKIP LOCKED falls back to an ordinary lookup on SQLite.
    #[tokio::test]
    async fn lock_with_option_skip_locked_degrades_on_sqlite() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let alice = TestUser::lock_with_option(1, LockOption::SkipLocked, &conn)
            .await
            .expect("skip-locked should degrade to a plain lookup on sqlite");
        assert_eq!(alice.name, "Alice");
    }

    /// NOWAIT is refused on SQLite with a message naming the clause.
    #[tokio::test]
    async fn lock_with_option_nowait_returns_informative_error_on_sqlite() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let outcome = TestUser::lock_with_option(1, LockOption::Nowait, &conn).await;
        let error = outcome.expect_err("sqlite should not claim NOWAIT support");
        let mentions_nowait =
            matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT"));
        assert!(mentions_nowait);
    }

    /// A missing row is still `NotFound` when a lock option is supplied.
    #[tokio::test]
    async fn lock_with_option_returns_not_found_for_missing_row() {
        let conn = setup_db().await;

        let error = TestUser::lock_with_option(99, LockOption::SkipLocked, &conn)
            .await
            .expect_err("missing row should still be missing");
        assert!(matches!(error, RecordError::NotFound));
    }

    /// `lock_bang` replaces a stale in-memory copy with the stored row.
    #[tokio::test]
    async fn lock_bang_reloads_latest_persisted_state() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let mut stale = TestUser::find(1, &conn).await.expect("row should exist");
        let mut writer = TestUser::find(1, &conn).await.expect("row should exist");
        writer.name = "Alicia".to_owned();
        writer.save(&conn).await.expect("update should persist");

        stale.lock_bang(&conn).await.expect("reload should succeed");

        assert_eq!(stale.name, "Alicia");
        assert_eq!(stale.email, "alice@example.com");
    }

    /// Unsaved records have no id, so `lock_bang` leaves them alone.
    #[tokio::test]
    async fn lock_bang_noops_for_new_records() {
        let conn = setup_db().await;
        let mut fresh = TestUser::default();

        fresh.lock_bang(&conn).await.expect("new records should no-op");

        assert!(fresh.new_record());
        assert_eq!(fresh.id(), None);
    }

    /// A successful closure has its writes committed and its value returned.
    #[tokio::test]
    async fn with_lock_commits_changes_on_success() {
        let conn = setup_db().await;
        seed_users(&conn).await;
        let mut alice = TestUser::find(1, &conn).await.expect("row should exist");

        let returned = alice
            .with_lock(&conn, LockOption::ForUpdate, |record, scope| {
                Box::pin(async move {
                    record.name = "Locked Alice".to_owned();
                    record.save(scope).await?;
                    Ok(record.name.clone())
                })
            })
            .await
            .expect("with_lock should commit successful changes");
        assert_eq!(returned, "Locked Alice");

        let persisted = TestUser::find(1, &conn)
            .await
            .expect("row should still exist");
        assert_eq!(persisted.name, "Locked Alice");
    }

    /// A failing closure has its writes rolled back and its error propagated.
    #[tokio::test]
    async fn with_lock_rolls_back_changes_on_error() {
        let conn = setup_db().await;
        seed_users(&conn).await;
        let mut alice = TestUser::find(1, &conn).await.expect("row should exist");

        let error = alice
            .with_lock(&conn, LockOption::ForUpdate, |record, scope| {
                Box::pin(async move {
                    record.name = "Should Roll Back".to_owned();
                    record.save(scope).await?;
                    Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
                })
            })
            .await
            .expect_err("error should trigger rollback");
        assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));

        let persisted = TestUser::find(1, &conn)
            .await
            .expect("row should still exist");
        assert_eq!(persisted.name, "Alice");
    }

    /// Inside an outer transaction, `with_lock` joins it instead of opening its own.
    #[tokio::test]
    async fn with_lock_inside_existing_transaction_reuses_current_scope() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        transaction(&conn, |outer| {
            let outer = outer.clone();
            Box::pin(async move {
                let mut bob = TestUser::find(2, &outer).await?;
                bob.with_lock(&outer, LockOption::ForUpdate, |record, scope| {
                    Box::pin(async move {
                        record.name = "Nested Bob".to_owned();
                        record.save(scope).await?;
                        Ok(())
                    })
                })
                .await?;
                Ok(())
            })
        })
        .await
        .expect("nested with_lock should succeed");

        let persisted = TestUser::find(2, &conn)
            .await
            .expect("row should still exist");
        assert_eq!(persisted.name, "Nested Bob");
    }

    /// With SKIP LOCKED on SQLite, the closure still runs against the loaded record.
    #[tokio::test]
    async fn with_lock_skip_locked_still_executes_closure_on_sqlite() {
        let conn = setup_db().await;
        seed_users(&conn).await;
        let mut carol = TestUser::find(3, &conn).await.expect("row should exist");

        let seen = carol
            .with_lock(&conn, LockOption::SkipLocked, |record, _| {
                Box::pin(async move {
                    record.name.push_str("-seen");
                    Ok(record.name.clone())
                })
            })
            .await
            .expect("skip-locked should still yield the record on sqlite");

        assert_eq!(seen, "Carol-seen");
        assert_eq!(carol.name, "Carol-seen");
    }

    /// NOWAIT on SQLite fails before the closure is ever invoked.
    #[tokio::test]
    async fn with_lock_nowait_returns_error_before_running_closure() {
        let conn = setup_db().await;
        seed_users(&conn).await;
        let mut alice = TestUser::find(1, &conn).await.expect("row should exist");
        let invoked = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&invoked);

        let error = alice
            .with_lock(&conn, LockOption::Nowait, move |_record, _| {
                flag.store(true, Ordering::SeqCst);
                Box::pin(async { Ok(()) })
            })
            .await
            .expect_err("sqlite NOWAIT should fail early");

        assert!(matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT")));
        assert!(!invoked.load(Ordering::SeqCst));
    }

    /// `lock` and `find` agree on the loaded record.
    #[tokio::test]
    async fn lock_matches_plain_find_for_same_row() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let via_lock = TestUser::lock(3, &conn).await.expect("row should lock");
        let via_find = TestUser::find(3, &conn).await.expect("row should find");
        assert_eq!(via_lock, via_find);
    }

    /// A later `lock` observes attribute updates made after an earlier one.
    #[tokio::test]
    async fn lock_reads_latest_persisted_values_after_update() {
        let conn = setup_db().await;
        seed_users(&conn).await;

        let mut bob = TestUser::lock(2, &conn).await.expect("row should lock");
        let changes = HashMap::from([("name".to_owned(), json!("Bobby"))]);
        bob.update_attributes(changes, &conn)
            .await
            .expect("update should succeed");

        let refreshed = TestUser::lock(2, &conn)
            .await
            .expect("updated row should lock");
        assert_eq!(refreshed.name, "Bobby");
        assert_eq!(refreshed.email, "bob@example.com");
        assert!(refreshed.persisted());
    }
}