1use std::{future::Future, pin::Pin};
2
3use sea_orm::{
4 ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection, EntityTrait, Iterable,
5 QueryFilter, QuerySelect,
6 sea_query::{LockBehavior, LockType},
7};
8
9use crate::{
10 base::{Record, RecordError, RecordState},
11 querying::AsyncQuerying,
12 relation::resolve_column,
13};
14
15pub type LockFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, RecordError>> + Send + 'a>>;
17
/// Which flavour of row lock a pessimistic-locking query should request.
///
/// Every variant is issued as an update (`FOR UPDATE`-style) lock; they differ
/// only in how a conflicting lock is handled. Backend support varies: SQLite
/// rejects [`LockOption::Nowait`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum LockOption {
    /// Plain `FOR UPDATE` (the default): wait for conflicting locks.
    #[default]
    ForUpdate,
    /// `FOR UPDATE NOWAIT`: on supporting backends, fail instead of waiting
    /// for a conflicting lock. Rejected up front on SQLite.
    Nowait,
    /// `FOR UPDATE SKIP LOCKED`: on supporting backends, rows that are already
    /// locked are skipped, so a locked row comes back as "not found".
    SkipLocked,
}
29
30#[allow(dead_code)]
35pub(crate) trait PessimisticLocking: Record {
36 #[allow(private_bounds)]
38 async fn lock(id: i64, db: &DatabaseConnection) -> Result<Self, RecordError>
39 where
40 Self: Sized + AsyncQuerying,
41 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
42 {
43 Self::lock_with_option(id, LockOption::ForUpdate, db).await
44 }
45
46 #[allow(private_bounds)]
48 async fn lock_with_option(
49 id: i64,
50 option: LockOption,
51 db: &DatabaseConnection,
52 ) -> Result<Self, RecordError>
53 where
54 Self: Sized + AsyncQuerying,
55 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
56 {
57 if matches!(db.get_database_backend(), DatabaseBackend::Sqlite)
58 && matches!(option, LockOption::Nowait)
59 {
60 return Err(RecordError::Invalid(
61 "SQLite does not support NOWAIT row locks".to_owned(),
62 ));
63 }
64
65 let primary_key = resolve_column::<Self>(Self::primary_key_name())?;
66 let query = match option {
67 LockOption::ForUpdate => Self::Entity::find()
68 .filter(primary_key.eq(id))
69 .lock_exclusive(),
70 LockOption::Nowait => Self::Entity::find()
71 .filter(primary_key.eq(id))
72 .lock_with_behavior(LockType::Update, LockBehavior::Nowait),
73 LockOption::SkipLocked => Self::Entity::find()
74 .filter(primary_key.eq(id))
75 .lock_with_behavior(LockType::Update, LockBehavior::SkipLocked),
76 };
77
78 let model = query.one(db).await?.ok_or(RecordError::NotFound)?;
79 let mut record = Self::from_sea_model(model);
80 record.set_record_state(RecordState::Persisted);
81 Ok(record)
82 }
83
84 async fn lock_bang(&mut self, db: &DatabaseConnection) -> Result<(), RecordError>
86 where
87 Self: Sized + AsyncQuerying,
88 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
89 {
90 self.lock_bang_with_option(LockOption::ForUpdate, db).await
91 }
92
93 async fn lock_bang_with_option(
95 &mut self,
96 option: LockOption,
97 db: &DatabaseConnection,
98 ) -> Result<(), RecordError>
99 where
100 Self: Sized + AsyncQuerying,
101 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
102 {
103 let Some(id) = self.id() else {
104 return Ok(());
105 };
106
107 *self = Self::lock_with_option(id, option, db).await?;
108 Ok(())
109 }
110
111 async fn with_lock<F, T>(
113 &mut self,
114 db: &DatabaseConnection,
115 option: LockOption,
116 f: F,
117 ) -> Result<T, RecordError>
118 where
119 Self: Sized + AsyncQuerying,
120 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
121 F: for<'a> FnOnce(&'a mut Self, &'a DatabaseConnection) -> LockFuture<'a, T> + Send,
122 T: Send,
123 {
124 let started_transaction = begin_lock_scope(db, option).await?;
125 let result = async {
126 self.lock_bang_with_option(option, db).await?;
127 f(self, db).await
128 }
129 .await;
130
131 if !started_transaction {
132 return result;
133 }
134
135 match result {
136 Ok(value) => {
137 db.execute_unprepared("COMMIT").await?;
138 Ok(value)
139 }
140 Err(error) => {
141 db.execute_unprepared("ROLLBACK").await?;
142 Err(error)
143 }
144 }
145 }
146}
147
148#[allow(dead_code)]
149async fn begin_lock_scope(
150 db: &DatabaseConnection,
151 option: LockOption,
152) -> Result<bool, RecordError> {
153 let begin_sql = match (db.get_database_backend(), option) {
154 (DatabaseBackend::Sqlite, LockOption::ForUpdate) => Some("BEGIN EXCLUSIVE"),
155 (DatabaseBackend::Sqlite, LockOption::SkipLocked) => Some("BEGIN"),
156 (DatabaseBackend::Sqlite, LockOption::Nowait) => {
157 return Err(RecordError::Invalid(
158 "SQLite does not support NOWAIT row locks".to_owned(),
159 ));
160 }
161 (_, _) => Some("BEGIN"),
162 };
163
164 let Some(begin_sql) = begin_sql else {
165 return Ok(false);
166 };
167
168 match db.execute_unprepared(begin_sql).await {
169 Ok(_) => Ok(true),
170 Err(error) if transaction_already_open(&error) => Ok(false),
171 Err(error) => Err(error.into()),
172 }
173}
174
175#[allow(dead_code)]
176fn transaction_already_open(error: &sea_orm::DbErr) -> bool {
177 let message = error.to_string().to_ascii_lowercase();
178 message.contains("within a transaction")
179 || message.contains("already a transaction")
180 || message.contains("transaction within a transaction")
181 || message.contains("cannot start a transaction")
182 || message.contains("transaction already in progress")
183}
184
// Tests for `PessimisticLocking`, run against the database from
// `base::test_support::setup_db`. The NOWAIT tests expect the SQLite error, so
// that fixture is presumably SQLite-backed.
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
    };

    use serde_json::json;

    use super::{LockOption, PessimisticLocking};
    use crate::{
        Record, RecordError,
        base::test_support::{TestUser, seed_users, setup_db},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
        transactions::transaction,
    };

    // All trait methods have default bodies, so an empty impl is enough.
    impl PessimisticLocking for TestUser {}

    #[tokio::test]
    async fn lock_returns_matching_record() {
        let db = setup_db().await;
        seed_users(&db).await;

        let user = TestUser::lock(2, &db).await.expect("row should load");

        assert_eq!(user.name, "Bob");
    }

    #[tokio::test]
    async fn lock_returns_not_found_for_missing_row() {
        // No seeding: the table is expected to be empty.
        let db = setup_db().await;

        let error = TestUser::lock(404, &db)
            .await
            .expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn lock_marks_record_as_persisted() {
        let db = setup_db().await;
        seed_users(&db).await;

        let user = TestUser::lock(1, &db).await.expect("row should load");

        assert!(user.persisted());
    }

    #[tokio::test]
    async fn lock_can_be_called_repeatedly() {
        let db = setup_db().await;
        seed_users(&db).await;

        // Outside a transaction the first lock must not block the second.
        let first = TestUser::lock(1, &db)
            .await
            .expect("first lock should work");
        let second = TestUser::lock(1, &db)
            .await
            .expect("second lock should work");

        assert_eq!(first.name, second.name);
    }

    #[tokio::test]
    async fn lock_does_not_change_row_count() {
        let db = setup_db().await;
        seed_users(&db).await;

        let _ = TestUser::lock(3, &db).await.expect("row should load");

        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 3);
    }

    #[tokio::test]
    async fn lock_preserves_identifier_and_email() {
        let db = setup_db().await;
        seed_users(&db).await;

        let user = TestUser::lock(2, &db).await.expect("row should load");

        assert_eq!(user.id(), Some(2));
        assert_eq!(user.email, "bob@example.com");
    }

    #[tokio::test]
    async fn lock_with_option_skip_locked_degrades_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;

        // SKIP LOCKED is not rejected on SQLite; the row is still returned.
        let user = TestUser::lock_with_option(1, LockOption::SkipLocked, &db)
            .await
            .expect("skip-locked should degrade to a plain lookup on sqlite");

        assert_eq!(user.name, "Alice");
    }

    #[tokio::test]
    async fn lock_with_option_nowait_returns_informative_error_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;

        let error = TestUser::lock_with_option(1, LockOption::Nowait, &db)
            .await
            .expect_err("sqlite should not claim NOWAIT support");

        assert!(matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT")));
    }

    #[tokio::test]
    async fn lock_with_option_returns_not_found_for_missing_row() {
        let db = setup_db().await;

        let error = TestUser::lock_with_option(99, LockOption::SkipLocked, &db)
            .await
            .expect_err("missing row should still be missing");

        assert!(matches!(error, RecordError::NotFound));
    }

    #[tokio::test]
    async fn lock_bang_reloads_latest_persisted_state() {
        let db = setup_db().await;
        seed_users(&db).await;

        // `user` holds a stale copy once `other` saves a new name.
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        let mut other = TestUser::find(1, &db).await.expect("row should exist");
        other.name = "Alicia".to_owned();
        other.save(&db).await.expect("update should persist");

        user.lock_bang(&db).await.expect("reload should succeed");

        assert_eq!(user.name, "Alicia");
        assert_eq!(user.email, "alice@example.com");
    }

    #[tokio::test]
    async fn lock_bang_noops_for_new_records() {
        let db = setup_db().await;
        let mut user = TestUser::default();

        // A record without an id is left untouched rather than queried.
        user.lock_bang(&db).await.expect("new records should no-op");

        assert!(user.new_record());
        assert_eq!(user.id(), None);
    }

    #[tokio::test]
    async fn with_lock_commits_changes_on_success() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");

        let updated_name = user
            .with_lock(&db, LockOption::ForUpdate, |locked, txn| {
                Box::pin(async move {
                    locked.name = "Locked Alice".to_owned();
                    locked.save(txn).await?;
                    Ok(locked.name.clone())
                })
            })
            .await
            .expect("with_lock should commit successful changes");

        // A fresh read must see the committed write.
        let reloaded = TestUser::find(1, &db)
            .await
            .expect("row should still exist");
        assert_eq!(updated_name, "Locked Alice");
        assert_eq!(reloaded.name, "Locked Alice");
    }

    #[tokio::test]
    async fn with_lock_rolls_back_changes_on_error() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");

        // The save succeeds, but the closure then errors, so the write must be
        // rolled back.
        let error = user
            .with_lock(&db, LockOption::ForUpdate, |locked, txn| {
                Box::pin(async move {
                    locked.name = "Should Roll Back".to_owned();
                    locked.save(txn).await?;
                    Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
                })
            })
            .await
            .expect_err("error should trigger rollback");

        assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));
        let reloaded = TestUser::find(1, &db)
            .await
            .expect("row should still exist");
        assert_eq!(reloaded.name, "Alice");
    }

    #[tokio::test]
    async fn with_lock_inside_existing_transaction_reuses_current_scope() {
        let db = setup_db().await;
        seed_users(&db).await;

        // The outer `transaction` presumably has BEGIN already issued, so
        // `with_lock` should detect the open transaction and leave the COMMIT
        // to the outer scope.
        transaction(&db, |txn| {
            let txn = txn.clone();
            Box::pin(async move {
                let mut user = TestUser::find(2, &txn).await?;
                user.with_lock(&txn, LockOption::ForUpdate, |locked, inner| {
                    Box::pin(async move {
                        locked.name = "Nested Bob".to_owned();
                        locked.save(inner).await?;
                        Ok(())
                    })
                })
                .await?;
                Ok(())
            })
        })
        .await
        .expect("nested with_lock should succeed");

        let reloaded = TestUser::find(2, &db)
            .await
            .expect("row should still exist");
        assert_eq!(reloaded.name, "Nested Bob");
    }

    #[tokio::test]
    async fn with_lock_skip_locked_still_executes_closure_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(3, &db).await.expect("row should exist");

        // The closure mutates `user` in place without saving, so the change is
        // visible on `user` itself after `with_lock` returns.
        let result = user
            .with_lock(&db, LockOption::SkipLocked, |locked, _| {
                Box::pin(async move {
                    locked.name.push_str("-seen");
                    Ok(locked.name.clone())
                })
            })
            .await
            .expect("skip-locked should still yield the record on sqlite");

        assert_eq!(result, "Carol-seen");
        assert_eq!(user.name, "Carol-seen");
    }

    #[tokio::test]
    async fn with_lock_nowait_returns_error_before_running_closure() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        // Flips to true if the closure is ever invoked.
        let ran = Arc::new(AtomicBool::new(false));

        let error = user
            .with_lock(&db, LockOption::Nowait, {
                let ran = Arc::clone(&ran);
                move |_locked, _| {
                    ran.store(true, Ordering::SeqCst);
                    Box::pin(async { Ok(()) })
                }
            })
            .await
            .expect_err("sqlite NOWAIT should fail early");

        assert!(matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT")));
        assert!(!ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn lock_matches_plain_find_for_same_row() {
        let db = setup_db().await;
        seed_users(&db).await;

        let locked = TestUser::lock(3, &db).await.expect("row should lock");
        let found = TestUser::find(3, &db).await.expect("row should find");

        assert_eq!(locked, found);
    }

    #[tokio::test]
    async fn lock_reads_latest_persisted_values_after_update() {
        let db = setup_db().await;
        seed_users(&db).await;

        let mut user = TestUser::lock(2, &db).await.expect("row should lock");
        user.update_attributes(HashMap::from([("name".to_owned(), json!("Bobby"))]), &db)
            .await
            .expect("update should succeed");

        // A second lock must see the updated name and keep the other columns.
        let refreshed = TestUser::lock(2, &db)
            .await
            .expect("updated row should lock");

        assert_eq!(refreshed.name, "Bobby");
        assert_eq!(refreshed.email, "bob@example.com");
        assert!(refreshed.persisted());
    }
}