Skip to main content

sql_orm/
active_record.rs

1use crate::{AuditEntity, DbContextEntitySet, DbSetQuery, SoftDeleteEntity, TenantScopedEntity};
2use core::future::Future;
3use sql_orm_core::{ColumnValue, Entity, FromRow, OrmError, SqlTypeMapping, SqlValue};
4
5#[doc(hidden)]
6pub trait EntityPrimaryKey: Entity {
7    fn primary_key_value(&self) -> Result<SqlValue, OrmError>;
8}
9
10#[doc(hidden)]
11pub enum EntityPersistMode {
12    Insert,
13    InsertOrUpdate(SqlValue),
14    Update(SqlValue),
15}
16
17#[doc(hidden)]
18pub trait EntityPersist: Entity {
19    fn persist_mode(&self) -> Result<EntityPersistMode, OrmError>;
20    fn insert_values(&self) -> Vec<ColumnValue>;
21    fn update_changes(&self) -> Vec<ColumnValue>;
22    fn concurrency_token(&self) -> Result<Option<SqlValue>, OrmError>;
23    fn sync_persisted(&mut self, persisted: Self);
24
25    #[doc(hidden)]
26    fn has_persisted_changes(original: &Self, current: &Self) -> bool {
27        original.update_changes() != current.update_changes()
28    }
29}
30
31/// Convenience Active Record style API for entities.
32///
33/// Every `Entity` implements this trait. The methods delegate to the `DbSet`
34/// declared on a `DbContext`, so Active Record remains a thin convenience
35/// layer over the same query, insert, update, delete, tenant, audit, and
36/// soft-delete pipelines used by the explicit context API.
37pub trait ActiveRecord: Entity + Sized {
38    /// Starts a query for this entity through the context's `DbSet<Self>`.
39    fn query<C>(db: &C) -> DbSetQuery<Self>
40    where
41        C: DbContextEntitySet<Self>,
42        Self: TenantScopedEntity,
43    {
44        db.db_set().query()
45    }
46
47    /// Finds one entity by single-column primary key through the context.
48    fn find<C, K>(db: &C, key: K) -> impl Future<Output = Result<Option<Self>, OrmError>> + Send
49    where
50        C: DbContextEntitySet<Self>,
51        Self: FromRow + Send + SoftDeleteEntity + TenantScopedEntity,
52        K: SqlTypeMapping + Send,
53    {
54        db.db_set().find(key)
55    }
56
57    /// Deletes this entity through the context's `DbSet<Self>`.
58    ///
59    /// Entities with `soft_delete` use logical delete. Entities with
60    /// rowversion participate in the same concurrency-conflict detection as
61    /// the explicit `DbSet` delete path.
62    fn delete<C>(&self, db: &C) -> impl Future<Output = Result<bool, OrmError>> + Send
63    where
64        C: DbContextEntitySet<Self> + Sync,
65        Self: EntityPrimaryKey
66            + EntityPersist
67            + FromRow
68            + Send
69            + SoftDeleteEntity
70            + TenantScopedEntity,
71    {
72        let key = <Self as EntityPrimaryKey>::primary_key_value(self);
73        let concurrency_token = <Self as EntityPersist>::concurrency_token(self);
74
75        async move {
76            db.db_set()
77                .delete_by_sql_value(key?, concurrency_token?)
78                .await
79        }
80    }
81
82    /// Inserts or updates this entity according to the derived persistence
83    /// strategy.
84    ///
85    /// The method syncs the in-memory entity with the persisted row returned
86    /// by SQL Server.
87    fn save<C>(&mut self, db: &C) -> impl Future<Output = Result<(), OrmError>> + Send
88    where
89        C: DbContextEntitySet<Self> + Sync,
90        Self: AuditEntity + EntityPersist + FromRow + Send + SoftDeleteEntity + TenantScopedEntity,
91    {
92        async move {
93            match <Self as EntityPersist>::persist_mode(self)? {
94                EntityPersistMode::Insert => {
95                    let persisted = db.db_set().insert_entity(self).await?;
96                    <Self as EntityPersist>::sync_persisted(self, persisted);
97                    Ok(())
98                }
99                EntityPersistMode::InsertOrUpdate(key) => {
100                    if db
101                        .db_set()
102                        .exists_by_sql_value_internal(key.clone())
103                        .await?
104                    {
105                        if let Some(persisted) = db
106                            .db_set()
107                            .update_entity_by_sql_value(
108                                key,
109                                self,
110                                <Self as EntityPersist>::concurrency_token(self)?,
111                            )
112                            .await?
113                        {
114                            <Self as EntityPersist>::sync_persisted(self, persisted);
115                        } else {
116                            return Err(OrmError::new(
117                                "ActiveRecord save could not update a row for the current primary key",
118                            ));
119                        }
120                    } else {
121                        let persisted = db.db_set().insert_entity(self).await?;
122                        <Self as EntityPersist>::sync_persisted(self, persisted);
123                    }
124
125                    Ok(())
126                }
127                EntityPersistMode::Update(key) => {
128                    let persisted = db
129                        .db_set()
130                        .update_entity_by_sql_value(
131                            key,
132                            self,
133                            <Self as EntityPersist>::concurrency_token(self)?,
134                        )
135                        .await?
136                        .ok_or_else(|| {
137                            OrmError::new(
138                                "ActiveRecord save could not update a row for the current primary key",
139                            )
140                        })?;
141                    <Self as EntityPersist>::sync_persisted(self, persisted);
142                    Ok(())
143                }
144            }
145        }
146    }
147}
148
149impl<E: Entity> ActiveRecord for E {}
150
#[cfg(test)]
mod tests {
    use super::{ActiveRecord, EntityPersist, EntityPersistMode, EntityPrimaryKey};
    use crate::{
        AuditEntity, DbContext, DbContextEntitySet, DbSet, SoftDeleteEntity, TenantScopedEntity,
        Tracked,
    };
    use sql_orm_core::{
        ColumnMetadata, ColumnValue, Entity, EntityMetadata, EntityPolicyMetadata, FromRow,
        OrmError, PrimaryKeyMetadata, Row, SqlServerType,
    };
    use sql_orm_query::SelectQuery;

    /// Minimal entity mapped to `dbo.test_entities` with a `BIGINT` key `id`
    /// and an `NVARCHAR(120)` column `name`.
    #[derive(Debug, Clone, PartialEq)]
    struct TestEntity {
        id: i64,
        name: String,
    }

    static TEST_ENTITY_COLUMNS: [ColumnMetadata; 2] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "name",
            column_name: "name",
            renamed_from: None,
            sql_type: SqlServerType::NVarChar,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: Some(120),
            precision: None,
            scale: None,
        },
    ];

    static TEST_ENTITY_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "TestEntity",
        schema: "dbo",
        table: "test_entities",
        renamed_from: None,
        columns: &TEST_ENTITY_COLUMNS,
        primary_key: PrimaryKeyMetadata {
            name: None,
            columns: &["id"],
        },
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for TestEntity {
        fn metadata() -> &'static EntityMetadata {
            &TEST_ENTITY_METADATA
        }
    }

    impl SoftDeleteEntity for TestEntity {
        fn soft_delete_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    impl AuditEntity for TestEntity {
        fn audit_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    impl TenantScopedEntity for TestEntity {
        fn tenant_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    impl FromRow for TestEntity {
        // Returns a fixed entity regardless of the row contents.
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(Self {
                id: 7,
                name: "Persisted".to_string(),
            })
        }
    }

    impl EntityPrimaryKey for TestEntity {
        fn primary_key_value(&self) -> Result<sql_orm_core::SqlValue, OrmError> {
            Ok(sql_orm_core::SqlValue::I64(self.id))
        }
    }

    impl EntityPersist for TestEntity {
        fn persist_mode(&self) -> Result<EntityPersistMode, OrmError> {
            Ok(EntityPersistMode::Update(sql_orm_core::SqlValue::I64(
                self.id,
            )))
        }

        fn insert_values(&self) -> Vec<ColumnValue> {
            vec![ColumnValue::new(
                "name",
                sql_orm_core::SqlValue::String(self.name.clone()),
            )]
        }

        fn update_changes(&self) -> Vec<ColumnValue> {
            vec![ColumnValue::new(
                "name",
                sql_orm_core::SqlValue::String(self.name.clone()),
            )]
        }

        fn concurrency_token(&self) -> Result<Option<sql_orm_core::SqlValue>, OrmError> {
            Ok(None)
        }

        fn sync_persisted(&mut self, persisted: Self) {
            *self = persisted;
        }
    }

    /// Context exposing a single `DbSet<TestEntity>` and no live connection.
    struct DummyContext {
        entities: DbSet<TestEntity>,
    }

    impl DbContext for DummyContext {
        fn from_shared_connection(_connection: crate::SharedConnection) -> Self {
            unreachable!("DummyContext is only used in disconnected unit tests")
        }

        fn shared_connection(&self) -> crate::SharedConnection {
            panic!("DummyContext is only used in disconnected unit tests")
        }

        fn tracking_registry(&self) -> crate::TrackingRegistryHandle {
            self.entities.tracking_registry()
        }
    }

    impl DbContextEntitySet<TestEntity> for DummyContext {
        fn db_set(&self) -> &DbSet<TestEntity> {
            &self.entities
        }
    }

    /// Builds a context around `DbSet::disconnected()`, so operations that
    /// need a connection fail (see the `*_reuses_dbset_error_contract` tests).
    fn disconnected_context() -> DummyContext {
        DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        }
    }

    #[test]
    fn active_record_query_delegates_to_typed_dbset() {
        let context = disconnected_context();

        let query = TestEntity::query(&context);

        assert_eq!(
            query.into_select_query(),
            SelectQuery::from_entity::<TestEntity>()
        );
    }

    #[test]
    fn active_record_trait_is_available_for_entities() {
        fn require_active_record<E: ActiveRecord>() {}

        require_active_record::<TestEntity>();
    }

    #[tokio::test]
    async fn tracked_save_unchanged_is_noop_without_dereferencing_to_active_record() {
        let context = disconnected_context();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });

        tracked.save(&context).await.unwrap();

        assert_eq!(tracked.state(), crate::EntityState::Unchanged);
        assert_eq!(tracked.original(), tracked.current());
    }

    #[tokio::test]
    async fn tracked_save_unchanged_registered_entry_remains_tracked() {
        let context = disconnected_context();
        let registry = context.entities.tracking_registry();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        tracked
            .attach_registry_loaded(registry.clone(), sql_orm_core::SqlValue::I64(7))
            .unwrap();

        tracked.save(&context).await.unwrap();

        assert_eq!(tracked.state(), crate::EntityState::Unchanged);
        assert_eq!(registry.entry_count(), 1);
        assert_eq!(
            registry.registrations()[0].state,
            crate::EntityState::Unchanged
        );
    }

    #[tokio::test]
    async fn tracked_save_deleted_returns_stable_error_before_active_record() {
        let context = disconnected_context();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        context.entities.remove_tracked(&mut tracked);

        let error = tracked.save(&context).await.unwrap_err();

        assert_eq!(
            error.message(),
            "tracked deleted entities cannot be saved; detach them or persist deletion"
        );
    }

    #[tokio::test]
    async fn tracked_save_deleted_registered_entry_keeps_pending_delete_after_error() {
        let context = disconnected_context();
        let registry = context.entities.tracking_registry();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        tracked
            .attach_registry_loaded(registry.clone(), sql_orm_core::SqlValue::I64(7))
            .unwrap();
        context.entities.remove_tracked(&mut tracked);

        let error = tracked.save(&context).await.unwrap_err();

        assert_eq!(
            error.message(),
            "tracked deleted entities cannot be saved; detach them or persist deletion"
        );
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 1);
        assert_eq!(
            registry.registrations()[0].state,
            crate::EntityState::Deleted
        );
    }

    #[tokio::test]
    async fn tracked_delete_added_cancels_local_insert_without_active_record() {
        let context = disconnected_context();
        let registry = context.entities.tracking_registry();
        let mut tracked = context.entities.add_tracked(TestEntity {
            id: 0,
            name: "Pending".to_string(),
        });

        let deleted = tracked.delete(&context).await.unwrap();

        assert!(!deleted);
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 0);
    }

    #[tokio::test]
    async fn tracked_delete_deleted_entry_is_idempotent_without_active_record() {
        let context = disconnected_context();
        let registry = context.entities.tracking_registry();
        let mut tracked = context.entities.add_tracked(TestEntity {
            id: 0,
            name: "Pending".to_string(),
        });
        tracked.delete(&context).await.unwrap();

        let deleted = tracked.delete(&context).await.unwrap();

        assert!(!deleted);
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 0);
    }

    // The three contract tests below use `#[tokio::test]` like the rest of the
    // module instead of hand-building a multi-thread runtime; `expect_err`
    // still prints the unexpected `Ok` value on failure.

    #[tokio::test]
    async fn active_record_find_reuses_dbset_error_contract() {
        let context = disconnected_context();

        let error = TestEntity::find(&context, 1_i64)
            .await
            .expect_err("expected disconnected ActiveRecord::find to fail");

        assert_eq!(
            error,
            OrmError::new("DbSetQuery requires an initialized shared connection")
        );
    }

    #[tokio::test]
    async fn active_record_delete_reuses_dbset_error_contract() {
        let context = disconnected_context();
        let entity = TestEntity {
            id: 7,
            name: "Ana".to_string(),
        };

        let error = entity
            .delete(&context)
            .await
            .expect_err("expected disconnected ActiveRecord::delete to fail");

        assert_eq!(
            error,
            OrmError::new("DbSet requires an initialized shared connection")
        );
    }

    #[tokio::test]
    async fn active_record_save_reuses_dbset_error_contract() {
        let context = disconnected_context();
        let mut entity = TestEntity {
            id: 7,
            name: "Ana".to_string(),
        };

        let error = entity
            .save(&context)
            .await
            .expect_err("expected disconnected ActiveRecord::save to fail");

        assert_eq!(
            error,
            OrmError::new("DbSet requires an initialized shared connection")
        );
    }
}