use sea_orm::{
    ActiveModelBehavior, ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection,
    EntityTrait, IntoActiveModel, Iterable, Statement,
};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::{
8 base::{Record, RecordError, RecordState},
9 persistence::AsyncPersistence,
10 querying::AsyncQuerying,
11 relation::json_to_sea_value,
12};
13
14#[allow(dead_code)]
16pub(crate) trait OptimisticLocking: Record {
17 fn lock_version(&self) -> i64;
19
20 fn increment_lock_version(&mut self);
22
23 fn locking_column() -> &'static str {
25 "lock_version"
26 }
27
28 fn lock_optimistically() -> bool {
30 true
31 }
32
33 fn locking_enabled() -> bool {
35 Self::lock_optimistically()
36 }
37
38 fn check_lock_version(&self, expected: i64) -> Result<(), StaleObjectError> {
40 if self.lock_version() == expected {
41 Ok(())
42 } else {
43 Err(StaleObjectError)
44 }
45 }
46
47 async fn current_lock_version(&self, db: &DatabaseConnection) -> Result<i64, RecordError>
49 where
50 Self: Sized + AsyncQuerying,
51 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
52 {
53 let id = self.id().ok_or(RecordError::NotSaved)?;
54 let fresh = Self::find(id, db).await?;
55 Ok(fresh.lock_version())
56 }
57
58 async fn save_with_optimistic_lock(
60 &mut self,
61 db: &DatabaseConnection,
62 ) -> Result<(), RecordError>
63 where
64 Self: Sized + AsyncPersistence + AsyncQuerying + Serialize + DeserializeOwned,
65 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
66 <Self::Entity as EntityTrait>::Model:
67 IntoActiveModel<<Self::Entity as EntityTrait>::ActiveModel>,
68 <Self::Entity as EntityTrait>::ActiveModel: ActiveModelBehavior + Send,
69 {
70 if self.destroyed() {
71 return Err(RecordError::NotSaved);
72 }
73
74 if !Self::locking_enabled() || self.new_record() {
75 return AsyncPersistence::save(self, db).await;
76 }
77
78 let id = self.id().ok_or(RecordError::NotSaved)?;
79 let expected = self.lock_version();
80 let current = self.current_lock_version(db).await?;
81 if current != expected {
82 return Err(StaleObjectError.into());
83 }
84
85 let assignments =
86 serialized_assignments(self, &[Self::primary_key_name(), Self::locking_column()])?;
87 let sql = build_update_sql::<Self>(&assignments);
88 let mut values = assignments
89 .into_iter()
90 .map(|(_, value)| value)
91 .collect::<Vec<_>>();
92 values.push(id.into());
93 values.push(expected.into());
94
95 let result = db
96 .execute_raw(Statement::from_sql_and_values(
97 db.get_database_backend(),
98 sql,
99 values,
100 ))
101 .await?;
102
103 if result.rows_affected() == 0 {
104 return Err(StaleObjectError.into());
105 }
106
107 self.increment_lock_version();
108 self.set_record_state(RecordState::Persisted);
109 Ok(())
110 }
111
112 async fn destroy_with_optimistic_lock(
114 &mut self,
115 db: &DatabaseConnection,
116 ) -> Result<(), RecordError>
117 where
118 Self: Sized + AsyncPersistence,
119 <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
120 {
121 if self.destroyed() {
122 return Err(RecordError::NotSaved);
123 }
124
125 if !Self::locking_enabled() {
126 return AsyncPersistence::destroy(self, db).await;
127 }
128
129 let id = self.id().ok_or(RecordError::NotSaved)?;
130 let result = db
131 .execute_raw(Statement::from_sql_and_values(
132 db.get_database_backend(),
133 format!(
134 "DELETE FROM {table} WHERE {primary_key} = ? AND {locking_column} = ?",
135 table = Self::table_name(),
136 primary_key = Self::primary_key_name(),
137 locking_column = Self::locking_column(),
138 ),
139 [id.into(), self.lock_version().into()],
140 ))
141 .await?;
142
143 if result.rows_affected() == 0 {
144 return Err(StaleObjectError.into());
145 }
146
147 self.set_record_state(RecordState::Destroyed);
148 Ok(())
149 }
150}
151
152#[allow(dead_code)]
153fn serialized_assignments<T: Serialize>(
154 record: &T,
155 excluded_columns: &[&str],
156) -> Result<Vec<(String, sea_orm::Value)>, RecordError> {
157 let json =
158 serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
159 let object = json
160 .as_object()
161 .ok_or_else(|| RecordError::Invalid("record must serialize to a JSON object".to_owned()))?;
162
163 let mut assignments = Vec::new();
164 for (column, value) in object {
165 if excluded_columns
166 .iter()
167 .any(|excluded| excluded == &column.as_str())
168 {
169 continue;
170 }
171
172 assignments.push((column.clone(), json_to_sea_value(value)?));
173 }
174
175 Ok(assignments)
176}
177
178#[allow(dead_code)]
179fn build_update_sql<T: OptimisticLocking>(assignments: &[(String, sea_orm::Value)]) -> String {
180 let mut sql = format!("UPDATE {} SET ", T::table_name());
181 if assignments.is_empty() {
182 sql.push_str(&format!(
183 "{locking_column} = {locking_column} + 1",
184 locking_column = T::locking_column(),
185 ));
186 } else {
187 for (index, (column, _)) in assignments.iter().enumerate() {
188 if index > 0 {
189 sql.push_str(", ");
190 }
191 sql.push_str(column);
192 sql.push_str(" = ?");
193 }
194 sql.push_str(&format!(
195 ", {locking_column} = {locking_column} + 1",
196 locking_column = T::locking_column(),
197 ));
198 }
199
200 sql.push_str(&format!(
201 " WHERE {primary_key} = ? AND {locking_column} = ?",
202 primary_key = T::primary_key_name(),
203 locking_column = T::locking_column(),
204 ));
205 sql
206}
207
208#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
210#[error("stale object error: record was updated by another process")]
211pub struct StaleObjectError;
212
213impl From<StaleObjectError> for RecordError {
214 fn from(_: StaleObjectError) -> Self {
215 RecordError::StaleObject
216 }
217}
218
#[cfg(test)]
mod tests {
    //! Tests for `OptimisticLocking` against an in-memory SQLite `versioned_users` table.

    use sea_orm::{
        ActiveValue::{NotSet, Set},
        ConnectionTrait, Database, DatabaseConnection, EntityTrait, Statement,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::collections::HashMap;

    use super::{OptimisticLocking, StaleObjectError};
    use crate::{
        base::{Record, RecordError, RecordState},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
    };

    /// Minimal SeaORM entity backing both test record types.
    mod versioned_user {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "versioned_users")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub name: String,
            pub email: String,
            pub lock_version: i64,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Serde default for the skipped `state` field: deserialized records start as `New`.
    fn default_record_state() -> RecordState {
        RecordState::New
    }

    /// Test record with optimistic locking enabled.
    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct VersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    /// Same shape as `VersionedUser`, but with optimistic locking switched off.
    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct UnlockedVersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    // Implements `Record`, `AsyncPersistence`, `AsyncQuerying`, and `OptimisticLocking` for a
    // test record mapped onto `versioned_user::Entity`. `$lock_optimistically` becomes the
    // type's `lock_optimistically()` value.
    macro_rules! impl_versioned_record {
        ($name:ident, $lock_optimistically:expr) => {
            impl Record for $name {
                type Entity = versioned_user::Entity;

                fn table_name() -> &'static str {
                    "versioned_users"
                }

                fn id(&self) -> Option<i64> {
                    self.id
                }

                fn record_state(&self) -> RecordState {
                    self.state
                }

                fn set_record_state(&mut self, state: RecordState) {
                    self.state = state;
                }

                fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
                    Self {
                        id: Some(i64::from(model.id)),
                        name: model.name,
                        email: model.email,
                        lock_version: model.lock_version,
                        state: RecordState::Persisted,
                    }
                }

                fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
                    versioned_user::ActiveModel {
                        id: match self.id.and_then(|value| i32::try_from(value).ok()) {
                            Some(value) => Set(value),
                            None => NotSet,
                        },
                        name: Set(self.name.clone()),
                        email: Set(self.email.clone()),
                        lock_version: Set(self.lock_version),
                    }
                }
            }

            impl AsyncPersistence for $name {}
            impl AsyncQuerying for $name {}

            impl OptimisticLocking for $name {
                fn lock_version(&self) -> i64 {
                    self.lock_version
                }

                fn increment_lock_version(&mut self) {
                    self.lock_version += 1;
                }

                fn lock_optimistically() -> bool {
                    $lock_optimistically
                }
            }
        };
    }

    impl_versioned_record!(VersionedUser, true);
    impl_versioned_record!(UnlockedVersionedUser, false);

    /// Opens a fresh in-memory SQLite database and creates `versioned_users` from the entity.
    async fn setup_db() -> DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let schema = sea_orm::Schema::new(db.get_database_backend());
        db.execute(&schema.create_table_from_entity(versioned_user::Entity))
            .await
            .expect("versioned_users table should be created");
        db
    }

    /// Inserts a locked `VersionedUser` through `AsyncPersistence::create`.
    async fn create_user(db: &DatabaseConnection, name: &str, email: &str) -> VersionedUser {
        VersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!(name)),
                ("email".to_owned(), json!(email)),
            ]),
            db,
        )
        .await
        .expect("create should succeed")
    }

    /// Reads the stored `lock_version` for row `id` directly, bypassing the record layer.
    async fn db_lock_version(db: &DatabaseConnection, id: i64) -> i64 {
        db.query_one_raw(Statement::from_sql_and_values(
            db.get_database_backend(),
            "SELECT lock_version FROM versioned_users WHERE id = ?".to_owned(),
            [id.into()],
        ))
        .await
        .expect("lock-version query should succeed")
        .expect("row should exist")
        .try_get("", "lock_version")
        .expect("lock_version should be readable")
    }

    /// Overwrites the stored `lock_version` for row `id`, simulating a concurrent writer.
    async fn set_db_lock_version(db: &DatabaseConnection, id: i64, version: i64) {
        db.execute_raw(Statement::from_sql_and_values(
            db.get_database_backend(),
            "UPDATE versioned_users SET lock_version = ? WHERE id = ?".to_owned(),
            [version.into(), id.into()],
        ))
        .await
        .expect("direct version bump should succeed");
    }

    #[test]
    fn locking_enabled_defaults_to_true() {
        assert!(VersionedUser::locking_enabled());
        assert!(VersionedUser::lock_optimistically());
    }

    #[test]
    fn lock_optimistically_can_disable_locking() {
        assert!(!UnlockedVersionedUser::lock_optimistically());
        assert!(!UnlockedVersionedUser::locking_enabled());
    }

    #[test]
    fn stale_object_error_display_is_descriptive() {
        assert_eq!(
            StaleObjectError.to_string(),
            "stale object error: record was updated by another process"
        );
    }

    #[test]
    fn stale_object_error_converts_to_record_error() {
        let error: RecordError = StaleObjectError.into();
        assert!(matches!(error, RecordError::StaleObject));
    }

    #[test]
    fn check_lock_version_accepts_matching_version() {
        let user = VersionedUser {
            lock_version: 3,
            ..Default::default()
        };

        user.check_lock_version(3)
            .expect("matching version should succeed");
    }

    #[test]
    fn check_lock_version_rejects_stale_version() {
        let user = VersionedUser {
            lock_version: 4,
            ..Default::default()
        };

        let error = user
            .check_lock_version(3)
            .expect_err("mismatched version should fail");
        assert_eq!(error, StaleObjectError);
    }

    #[test]
    fn increment_lock_version_updates_counter() {
        let mut user = VersionedUser::default();

        user.increment_lock_version();
        user.increment_lock_version();

        assert_eq!(user.lock_version(), 2);
    }

    #[tokio::test]
    async fn current_lock_version_reads_database_value() {
        let db = setup_db().await;
        let user = create_user(&db, "Alice", "alice@example.com").await;

        assert_eq!(user.current_lock_version(&db).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_inserts_new_record_at_version_zero() {
        let db = setup_db().await;
        let mut user = VersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };

        user.save_with_optimistic_lock(&db)
            .await
            .expect("insert should succeed");

        assert!(user.persisted());
        assert_eq!(user.lock_version(), 0);
        assert_eq!(db_lock_version(&db, user.id().unwrap()).await, 0);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_updates_row_and_bumps_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.name = "Alicia".to_owned();

        user.save_with_optimistic_lock(&db)
            .await
            .expect("optimistic update should succeed");

        let reloaded = VersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("reloaded row should exist");
        assert_eq!(user.lock_version(), 1);
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 1);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_rejects_stale_version() {
        let db = setup_db().await;
        let mut fresh = create_user(&db, "Alice", "alice@example.com").await;
        let mut stale = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("second copy should exist");

        fresh.name = "Alicia".to_owned();
        fresh
            .save_with_optimistic_lock(&db)
            .await
            .expect("first update should succeed");
        stale.name = "Outdated".to_owned();

        let error = stale
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("stale update should fail");

        assert!(matches!(error, RecordError::StaleObject));
        assert_eq!(stale.lock_version(), 0);
        assert_eq!(db_lock_version(&db, fresh.id().unwrap()).await, 1);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_detects_external_version_bump() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        let id = user.id().expect("created user should have an id");

        set_db_lock_version(&db, id, 2).await;

        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("external version bump should cause stale error");
        assert!(matches!(error, RecordError::StaleObject));
        assert_eq!(user.lock_version(), 0);
        assert_eq!(db_lock_version(&db, id).await, 2);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_returns_not_saved_for_destroyed_records() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.set_record_state(RecordState::Destroyed);

        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("destroyed records cannot be saved");
        assert!(matches!(error, RecordError::NotSaved));
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_deletes_matching_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;

        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should succeed");

        assert!(user.destroyed());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_rejects_stale_record() {
        let db = setup_db().await;
        let fresh = create_user(&db, "Alice", "alice@example.com").await;
        let id = fresh.id().expect("created user should have an id");
        let mut stale = VersionedUser::find(id, &db)
            .await
            .expect("stale copy should exist");

        set_db_lock_version(&db, id, 1).await;

        let error = stale
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("stale destroy should fail");
        assert!(matches!(error, RecordError::StaleObject));
        assert!(stale.persisted());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_returns_not_saved_without_id() {
        let db = setup_db().await;
        let mut user = VersionedUser::default();

        let error = user
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("new records cannot be destroyed optimistically");
        assert!(matches!(error, RecordError::NotSaved));
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_uses_regular_save_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };

        user.save_with_optimistic_lock(&db)
            .await
            .expect("save should delegate when locking is disabled");
        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db)
            .await
            .expect("update should delegate when locking is disabled");

        let reloaded = UnlockedVersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("row should still reload");
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 0);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_uses_regular_destroy_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!("Alice")),
                ("email".to_owned(), json!("alice@example.com")),
            ]),
            &db,
        )
        .await
        .expect("create should succeed");

        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should delegate when locking is disabled");

        assert!(user.destroyed());
        assert_eq!(UnlockedVersionedUser::count(&db).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn multiple_successive_optimistic_saves_keep_advancing_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;

        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();
        user.email = "alicia@example.com".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();

        assert_eq!(user.lock_version(), 2);
        let reloaded = VersionedUser::find(user.id().unwrap(), &db).await.unwrap();
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.email, "alicia@example.com");
        assert_eq!(reloaded.lock_version(), 2);
    }
}
649}