Skip to main content

rustrails_record/locking/optimistic.rs

1use sea_orm::{
2    ActiveModelBehavior, ColumnTrait, ConnectionTrait, DatabaseConnection, EntityTrait,
3    IntoActiveModel, Iterable, Statement,
4};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::{
8    base::{Record, RecordError, RecordState},
9    persistence::AsyncPersistence,
10    querying::AsyncQuerying,
11    relation::json_to_sea_value,
12};
13
14/// Optimistic locking support backed by a `lock_version`-style column.
15#[allow(dead_code)]
16pub(crate) trait OptimisticLocking: Record {
17    /// Returns the current lock version.
18    fn lock_version(&self) -> i64;
19
20    /// Increments the in-memory lock version after a successful update.
21    fn increment_lock_version(&mut self);
22
23    /// The optimistic locking column name.
24    fn locking_column() -> &'static str {
25        "lock_version"
26    }
27
28    /// Whether the model class opts into optimistic locking.
29    fn lock_optimistically() -> bool {
30        true
31    }
32
33    /// Whether optimistic locking is active for this record type.
34    fn locking_enabled() -> bool {
35        Self::lock_optimistically()
36    }
37
38    /// Verifies the current lock version against the expected version.
39    fn check_lock_version(&self, expected: i64) -> Result<(), StaleObjectError> {
40        if self.lock_version() == expected {
41            Ok(())
42        } else {
43            Err(StaleObjectError)
44        }
45    }
46
47    /// Loads the persisted lock version for this record.
48    async fn current_lock_version(&self, db: &DatabaseConnection) -> Result<i64, RecordError>
49    where
50        Self: Sized + AsyncQuerying,
51        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
52    {
53        let id = self.id().ok_or(RecordError::NotSaved)?;
54        let fresh = Self::find(id, db).await?;
55        Ok(fresh.lock_version())
56    }
57
58    /// Saves the record while enforcing the current optimistic lock version.
59    async fn save_with_optimistic_lock(
60        &mut self,
61        db: &DatabaseConnection,
62    ) -> Result<(), RecordError>
63    where
64        Self: Sized + AsyncPersistence + AsyncQuerying + Serialize + DeserializeOwned,
65        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
66        <Self::Entity as EntityTrait>::Model:
67            IntoActiveModel<<Self::Entity as EntityTrait>::ActiveModel>,
68        <Self::Entity as EntityTrait>::ActiveModel: ActiveModelBehavior + Send,
69    {
70        if self.destroyed() {
71            return Err(RecordError::NotSaved);
72        }
73
74        if !Self::locking_enabled() || self.new_record() {
75            return AsyncPersistence::save(self, db).await;
76        }
77
78        let id = self.id().ok_or(RecordError::NotSaved)?;
79        let expected = self.lock_version();
80        let current = self.current_lock_version(db).await?;
81        if current != expected {
82            return Err(StaleObjectError.into());
83        }
84
85        let assignments =
86            serialized_assignments(self, &[Self::primary_key_name(), Self::locking_column()])?;
87        let sql = build_update_sql::<Self>(&assignments);
88        let mut values = assignments
89            .into_iter()
90            .map(|(_, value)| value)
91            .collect::<Vec<_>>();
92        values.push(id.into());
93        values.push(expected.into());
94
95        let result = db
96            .execute_raw(Statement::from_sql_and_values(
97                db.get_database_backend(),
98                sql,
99                values,
100            ))
101            .await?;
102
103        if result.rows_affected() == 0 {
104            return Err(StaleObjectError.into());
105        }
106
107        self.increment_lock_version();
108        self.set_record_state(RecordState::Persisted);
109        Ok(())
110    }
111
112    /// Destroys the record while enforcing the current optimistic lock version.
113    async fn destroy_with_optimistic_lock(
114        &mut self,
115        db: &DatabaseConnection,
116    ) -> Result<(), RecordError>
117    where
118        Self: Sized + AsyncPersistence,
119        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
120    {
121        if self.destroyed() {
122            return Err(RecordError::NotSaved);
123        }
124
125        if !Self::locking_enabled() {
126            return AsyncPersistence::destroy(self, db).await;
127        }
128
129        let id = self.id().ok_or(RecordError::NotSaved)?;
130        let result = db
131            .execute_raw(Statement::from_sql_and_values(
132                db.get_database_backend(),
133                format!(
134                    "DELETE FROM {table} WHERE {primary_key} = ? AND {locking_column} = ?",
135                    table = Self::table_name(),
136                    primary_key = Self::primary_key_name(),
137                    locking_column = Self::locking_column(),
138                ),
139                [id.into(), self.lock_version().into()],
140            ))
141            .await?;
142
143        if result.rows_affected() == 0 {
144            return Err(StaleObjectError.into());
145        }
146
147        self.set_record_state(RecordState::Destroyed);
148        Ok(())
149    }
150}
151
152#[allow(dead_code)]
153fn serialized_assignments<T: Serialize>(
154    record: &T,
155    excluded_columns: &[&str],
156) -> Result<Vec<(String, sea_orm::Value)>, RecordError> {
157    let json =
158        serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
159    let object = json
160        .as_object()
161        .ok_or_else(|| RecordError::Invalid("record must serialize to a JSON object".to_owned()))?;
162
163    let mut assignments = Vec::new();
164    for (column, value) in object {
165        if excluded_columns
166            .iter()
167            .any(|excluded| excluded == &column.as_str())
168        {
169            continue;
170        }
171
172        assignments.push((column.clone(), json_to_sea_value(value)?));
173    }
174
175    Ok(assignments)
176}
177
178#[allow(dead_code)]
179fn build_update_sql<T: OptimisticLocking>(assignments: &[(String, sea_orm::Value)]) -> String {
180    let mut sql = format!("UPDATE {} SET ", T::table_name());
181    if assignments.is_empty() {
182        sql.push_str(&format!(
183            "{locking_column} = {locking_column} + 1",
184            locking_column = T::locking_column(),
185        ));
186    } else {
187        for (index, (column, _)) in assignments.iter().enumerate() {
188            if index > 0 {
189                sql.push_str(", ");
190            }
191            sql.push_str(column);
192            sql.push_str(" = ?");
193        }
194        sql.push_str(&format!(
195            ", {locking_column} = {locking_column} + 1",
196            locking_column = T::locking_column(),
197        ));
198    }
199
200    sql.push_str(&format!(
201        " WHERE {primary_key} = ? AND {locking_column} = ?",
202        primary_key = T::primary_key_name(),
203        locking_column = T::locking_column(),
204    ));
205    sql
206}
207
208/// Error returned when an optimistic lock detects stale state.
209#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
210#[error("stale object error: record was updated by another process")]
211pub struct StaleObjectError;
212
213impl From<StaleObjectError> for RecordError {
214    fn from(_: StaleObjectError) -> Self {
215        RecordError::StaleObject
216    }
217}
218
#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveValue::{NotSet, Set},
        ConnectionTrait, Database, DatabaseConnection, EntityTrait, Statement,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::collections::HashMap;

    use super::{OptimisticLocking, StaleObjectError};
    use crate::{
        base::{Record, RecordError, RecordState},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
    };

    mod versioned_user {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "versioned_users")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub name: String,
            pub email: String,
            pub lock_version: i64,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    fn default_record_state() -> RecordState {
        RecordState::New
    }

    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct VersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct UnlockedVersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    /// Implements `Record`, the async persistence/query traits, and
    /// `OptimisticLocking` for a test struct backed by `versioned_users`,
    /// with `lock_optimistically()` returning `$lock_optimistically`.
    macro_rules! impl_versioned_record {
        ($name:ident, $lock_optimistically:expr) => {
            impl Record for $name {
                type Entity = versioned_user::Entity;

                fn table_name() -> &'static str {
                    "versioned_users"
                }

                fn id(&self) -> Option<i64> {
                    self.id
                }

                fn record_state(&self) -> RecordState {
                    self.state
                }

                fn set_record_state(&mut self, state: RecordState) {
                    self.state = state;
                }

                fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
                    Self {
                        id: Some(i64::from(model.id)),
                        name: model.name,
                        email: model.email,
                        lock_version: model.lock_version,
                        state: RecordState::Persisted,
                    }
                }

                fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
                    versioned_user::ActiveModel {
                        id: match self.id.and_then(|value| i32::try_from(value).ok()) {
                            Some(value) => Set(value),
                            None => NotSet,
                        },
                        name: Set(self.name.clone()),
                        email: Set(self.email.clone()),
                        lock_version: Set(self.lock_version),
                    }
                }
            }

            impl AsyncPersistence for $name {}
            impl AsyncQuerying for $name {}

            impl OptimisticLocking for $name {
                fn lock_version(&self) -> i64 {
                    self.lock_version
                }

                fn increment_lock_version(&mut self) {
                    self.lock_version += 1;
                }

                fn lock_optimistically() -> bool {
                    $lock_optimistically
                }
            }
        };
    }

    impl_versioned_record!(VersionedUser, true);
    impl_versioned_record!(UnlockedVersionedUser, false);

    /// Opens a fresh in-memory SQLite database with the `versioned_users` table.
    async fn setup_db() -> DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let schema = sea_orm::Schema::new(db.get_database_backend());
        db.execute(&schema.create_table_from_entity(versioned_user::Entity))
            .await
            .expect("versioned_users table should be created");
        db
    }

    /// Inserts a `VersionedUser` through the regular `create` path.
    async fn create_user(db: &DatabaseConnection, name: &str, email: &str) -> VersionedUser {
        VersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!(name)),
                ("email".to_owned(), json!(email)),
            ]),
            db,
        )
        .await
        .expect("create should succeed")
    }

    /// Reads the stored `lock_version` for row `id` directly with SQL.
    async fn db_lock_version(db: &DatabaseConnection, id: i64) -> i64 {
        db.query_one_raw(Statement::from_sql_and_values(
            db.get_database_backend(),
            "SELECT lock_version FROM versioned_users WHERE id = ?".to_owned(),
            [id.into()],
        ))
        .await
        .expect("lock-version query should succeed")
        .expect("row should exist")
        .try_get("", "lock_version")
        .expect("lock_version should be readable")
    }

    /// Overwrites the stored `lock_version` for row `id` directly with SQL,
    /// bypassing the record API. Binds the real id instead of assuming the
    /// row is `id = 1`, and asserts the update actually hit that row.
    async fn set_db_lock_version(db: &DatabaseConnection, id: i64, version: i64) {
        let result = db
            .execute_raw(Statement::from_sql_and_values(
                db.get_database_backend(),
                "UPDATE versioned_users SET lock_version = ? WHERE id = ?".to_owned(),
                [version.into(), id.into()],
            ))
            .await
            .expect("direct version bump should succeed");
        assert_eq!(
            result.rows_affected(),
            1,
            "direct version bump should hit exactly one row"
        );
    }

    #[test]
    fn locking_enabled_defaults_to_true() {
        assert!(VersionedUser::locking_enabled());
        assert!(VersionedUser::lock_optimistically());
    }

    #[test]
    fn lock_optimistically_can_disable_locking() {
        assert!(!UnlockedVersionedUser::lock_optimistically());
        assert!(!UnlockedVersionedUser::locking_enabled());
    }

    #[test]
    fn stale_object_error_display_is_descriptive() {
        assert_eq!(
            StaleObjectError.to_string(),
            "stale object error: record was updated by another process"
        );
    }

    #[test]
    fn stale_object_error_converts_to_record_error() {
        let error: RecordError = StaleObjectError.into();
        assert!(matches!(error, RecordError::StaleObject));
    }

    #[test]
    fn check_lock_version_accepts_matching_version() {
        let user = VersionedUser {
            lock_version: 3,
            ..Default::default()
        };

        user.check_lock_version(3)
            .expect("matching version should succeed");
    }

    #[test]
    fn check_lock_version_rejects_stale_version() {
        let user = VersionedUser {
            lock_version: 4,
            ..Default::default()
        };

        let error = user
            .check_lock_version(3)
            .expect_err("mismatched version should fail");
        assert_eq!(error, StaleObjectError);
    }

    #[test]
    fn increment_lock_version_updates_counter() {
        let mut user = VersionedUser::default();

        user.increment_lock_version();
        user.increment_lock_version();

        assert_eq!(user.lock_version(), 2);
    }

    #[tokio::test]
    async fn current_lock_version_reads_database_value() {
        let db = setup_db().await;
        let user = create_user(&db, "Alice", "alice@example.com").await;

        assert_eq!(user.current_lock_version(&db).await.unwrap(), 0);

        // The in-memory copy still says 0, so this only passes if the method
        // really reads the stored value.
        set_db_lock_version(&db, user.id().unwrap(), 5).await;
        assert_eq!(user.current_lock_version(&db).await.unwrap(), 5);
        assert_eq!(user.lock_version(), 0);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_inserts_new_record_at_version_zero() {
        let db = setup_db().await;
        let mut user = VersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };

        user.save_with_optimistic_lock(&db)
            .await
            .expect("insert should succeed");

        assert!(user.persisted());
        assert_eq!(user.lock_version(), 0);
        assert_eq!(db_lock_version(&db, user.id().unwrap()).await, 0);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_updates_row_and_bumps_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.name = "Alicia".to_owned();

        user.save_with_optimistic_lock(&db)
            .await
            .expect("optimistic update should succeed");

        let reloaded = VersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("reloaded row should exist");
        assert_eq!(user.lock_version(), 1);
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 1);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_rejects_stale_version() {
        let db = setup_db().await;
        let mut fresh = create_user(&db, "Alice", "alice@example.com").await;
        let mut stale = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("second copy should exist");

        fresh.name = "Alicia".to_owned();
        fresh
            .save_with_optimistic_lock(&db)
            .await
            .expect("first update should succeed");
        stale.name = "Outdated".to_owned();

        let error = stale
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("stale update should fail");

        assert!(matches!(error, RecordError::StaleObject));
        assert_eq!(stale.lock_version(), 0);
        assert_eq!(db_lock_version(&db, fresh.id().unwrap()).await, 1);

        // The rejected write must not have clobbered the committed data.
        let reloaded = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("row should still exist");
        assert_eq!(reloaded.name, "Alicia");
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_detects_external_version_bump() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;

        set_db_lock_version(&db, user.id().unwrap(), 2).await;

        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("external version bump should cause stale error");
        assert!(matches!(error, RecordError::StaleObject));
        assert_eq!(user.lock_version(), 0);
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_returns_not_saved_for_destroyed_records() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.set_record_state(RecordState::Destroyed);

        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("destroyed records cannot be saved");
        assert!(matches!(error, RecordError::NotSaved));
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_deletes_matching_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;

        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should succeed");

        assert!(user.destroyed());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_rejects_stale_record() {
        let db = setup_db().await;
        let fresh = create_user(&db, "Alice", "alice@example.com").await;
        let mut stale = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("stale copy should exist");

        set_db_lock_version(&db, fresh.id().unwrap(), 1).await;

        let error = stale
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("stale destroy should fail");
        assert!(matches!(error, RecordError::StaleObject));
        assert!(stale.persisted());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_returns_not_saved_without_id() {
        let db = setup_db().await;
        let mut user = VersionedUser::default();

        let error = user
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("new records cannot be destroyed optimistically");
        assert!(matches!(error, RecordError::NotSaved));
    }

    #[tokio::test]
    async fn save_with_optimistic_lock_uses_regular_save_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };

        user.save_with_optimistic_lock(&db)
            .await
            .expect("save should delegate when locking is disabled");
        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db)
            .await
            .expect("update should delegate when locking is disabled");

        let reloaded = UnlockedVersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("row should still reload");
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 0);
    }

    #[tokio::test]
    async fn destroy_with_optimistic_lock_uses_regular_destroy_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!("Alice")),
                ("email".to_owned(), json!("alice@example.com")),
            ]),
            &db,
        )
        .await
        .expect("create should succeed");

        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should delegate when locking is disabled");

        assert!(user.destroyed());
        assert_eq!(UnlockedVersionedUser::count(&db).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn multiple_successive_optimistic_saves_keep_advancing_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;

        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();
        user.email = "alicia@example.com".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();

        assert_eq!(user.lock_version(), 2);
        let reloaded = VersionedUser::find(user.id().unwrap(), &db).await.unwrap();
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.email, "alicia@example.com");
        assert_eq!(reloaded.lock_version(), 2);
    }
}