1use crate::{AuditEntity, DbContextEntitySet, DbSetQuery, SoftDeleteEntity, TenantScopedEntity};
2use core::future::Future;
3use sql_orm_core::{ColumnValue, Entity, FromRow, OrmError, SqlTypeMapping, SqlValue};
4
/// Exposes an entity's primary key as a single [`SqlValue`].
///
/// `ActiveRecord::delete` uses this value to address the row it deletes.
#[doc(hidden)]
pub trait EntityPrimaryKey: Entity {
    /// Returns the primary key of `self`, or an error if it cannot be produced.
    fn primary_key_value(&self) -> Result<SqlValue, OrmError>;
}
9
10#[doc(hidden)]
11pub enum EntityPersistMode {
12 Insert,
13 InsertOrUpdate(SqlValue),
14 Update(SqlValue),
15}
16
17#[doc(hidden)]
18pub trait EntityPersist: Entity {
19 fn persist_mode(&self) -> Result<EntityPersistMode, OrmError>;
20 fn insert_values(&self) -> Vec<ColumnValue>;
21 fn update_changes(&self) -> Vec<ColumnValue>;
22 fn concurrency_token(&self) -> Result<Option<SqlValue>, OrmError>;
23 fn sync_persisted(&mut self, persisted: Self);
24
25 #[doc(hidden)]
26 fn has_persisted_changes(original: &Self, current: &Self) -> bool {
27 original.update_changes() != current.update_changes()
28 }
29}
30
31pub trait ActiveRecord: Entity + Sized {
38 fn query<C>(db: &C) -> DbSetQuery<Self>
40 where
41 C: DbContextEntitySet<Self>,
42 Self: TenantScopedEntity,
43 {
44 db.db_set().query()
45 }
46
47 fn find<C, K>(db: &C, key: K) -> impl Future<Output = Result<Option<Self>, OrmError>> + Send
49 where
50 C: DbContextEntitySet<Self>,
51 Self: FromRow + Send + SoftDeleteEntity + TenantScopedEntity,
52 K: SqlTypeMapping + Send,
53 {
54 db.db_set().find(key)
55 }
56
57 fn delete<C>(&self, db: &C) -> impl Future<Output = Result<bool, OrmError>> + Send
63 where
64 C: DbContextEntitySet<Self> + Sync,
65 Self: EntityPrimaryKey
66 + EntityPersist
67 + FromRow
68 + Send
69 + SoftDeleteEntity
70 + TenantScopedEntity,
71 {
72 let key = <Self as EntityPrimaryKey>::primary_key_value(self);
73 let concurrency_token = <Self as EntityPersist>::concurrency_token(self);
74
75 async move {
76 db.db_set()
77 .delete_by_sql_value(key?, concurrency_token?)
78 .await
79 }
80 }
81
82 fn save<C>(&mut self, db: &C) -> impl Future<Output = Result<(), OrmError>> + Send
88 where
89 C: DbContextEntitySet<Self> + Sync,
90 Self: AuditEntity + EntityPersist + FromRow + Send + SoftDeleteEntity + TenantScopedEntity,
91 {
92 async move {
93 match <Self as EntityPersist>::persist_mode(self)? {
94 EntityPersistMode::Insert => {
95 let persisted = db.db_set().insert_entity(self).await?;
96 <Self as EntityPersist>::sync_persisted(self, persisted);
97 Ok(())
98 }
99 EntityPersistMode::InsertOrUpdate(key) => {
100 if db
101 .db_set()
102 .exists_by_sql_value_internal(key.clone())
103 .await?
104 {
105 if let Some(persisted) = db
106 .db_set()
107 .update_entity_by_sql_value(
108 key,
109 self,
110 <Self as EntityPersist>::concurrency_token(self)?,
111 )
112 .await?
113 {
114 <Self as EntityPersist>::sync_persisted(self, persisted);
115 } else {
116 return Err(OrmError::new(
117 "ActiveRecord save could not update a row for the current primary key",
118 ));
119 }
120 } else {
121 let persisted = db.db_set().insert_entity(self).await?;
122 <Self as EntityPersist>::sync_persisted(self, persisted);
123 }
124
125 Ok(())
126 }
127 EntityPersistMode::Update(key) => {
128 let persisted = db
129 .db_set()
130 .update_entity_by_sql_value(
131 key,
132 self,
133 <Self as EntityPersist>::concurrency_token(self)?,
134 )
135 .await?
136 .ok_or_else(|| {
137 OrmError::new(
138 "ActiveRecord save could not update a row for the current primary key",
139 )
140 })?;
141 <Self as EntityPersist>::sync_persisted(self, persisted);
142 Ok(())
143 }
144 }
145 }
146 }
147}
148
149impl<E: Entity> ActiveRecord for E {}
150
// Disconnected unit tests for `ActiveRecord` and for `Tracked` save/delete behavior.
// Every test uses a `DbSet::disconnected()` set, so any path that reaches the database
// fails with the set's "requires an initialized shared connection" error.
#[cfg(test)]
mod tests {
    use super::{ActiveRecord, EntityPersist, EntityPersistMode, EntityPrimaryKey};
    use crate::{
        AuditEntity, DbContext, DbContextEntitySet, DbSet, SoftDeleteEntity, TenantScopedEntity,
        Tracked,
    };
    use sql_orm_core::{
        ColumnMetadata, ColumnValue, Entity, EntityMetadata, EntityPolicyMetadata, FromRow,
        OrmError, PrimaryKeyMetadata, Row, SqlServerType,
    };
    use sql_orm_query::SelectQuery;

    /// Minimal entity with a `BigInt` key and one `NVarChar` column, shared by every test.
    #[derive(Debug, Clone, PartialEq)]
    struct TestEntity {
        id: i64,
        name: String,
    }

    /// Column metadata for `TestEntity`: `id` is the primary key and is not updatable;
    /// `name` is updatable with a maximum length of 120.
    static TEST_ENTITY_COLUMNS: [ColumnMetadata; 2] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "name",
            column_name: "name",
            renamed_from: None,
            sql_type: SqlServerType::NVarChar,
            nullable: false,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: Some(120),
            precision: None,
            scale: None,
        },
    ];

    /// Table metadata mapping `TestEntity` to `dbo.test_entities`, keyed on `id`.
    /// It declares no indexes, foreign keys, or navigations.
    static TEST_ENTITY_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "TestEntity",
        schema: "dbo",
        table: "test_entities",
        renamed_from: None,
        columns: &TEST_ENTITY_COLUMNS,
        primary_key: PrimaryKeyMetadata {
            name: None,
            columns: &["id"],
        },
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for TestEntity {
        fn metadata() -> &'static EntityMetadata {
            &TEST_ENTITY_METADATA
        }
    }

    // No soft-delete, audit, or tenant policy: each policy hook below returns `None`.
    impl SoftDeleteEntity for TestEntity {
        fn soft_delete_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    impl AuditEntity for TestEntity {
        fn audit_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    impl TenantScopedEntity for TestEntity {
        fn tenant_policy() -> Option<EntityPolicyMetadata> {
            None
        }
    }

    // Ignores the row contents and always yields `id: 7, name: "Persisted"`.
    impl FromRow for TestEntity {
        fn from_row<R: Row>(_row: &R) -> Result<Self, OrmError> {
            Ok(Self {
                id: 7,
                name: "Persisted".to_string(),
            })
        }
    }

    impl EntityPrimaryKey for TestEntity {
        fn primary_key_value(&self) -> Result<sql_orm_core::SqlValue, OrmError> {
            Ok(sql_orm_core::SqlValue::I64(self.id))
        }
    }

    // Always requests `Update` keyed by `id`, so `ActiveRecord::save` takes its update path.
    // There is no concurrency token, and syncing replaces `self` with the persisted value.
    impl EntityPersist for TestEntity {
        fn persist_mode(&self) -> Result<EntityPersistMode, OrmError> {
            Ok(EntityPersistMode::Update(sql_orm_core::SqlValue::I64(
                self.id,
            )))
        }

        fn insert_values(&self) -> Vec<ColumnValue> {
            vec![ColumnValue::new(
                "name",
                sql_orm_core::SqlValue::String(self.name.clone()),
            )]
        }

        fn update_changes(&self) -> Vec<ColumnValue> {
            vec![ColumnValue::new(
                "name",
                sql_orm_core::SqlValue::String(self.name.clone()),
            )]
        }

        fn concurrency_token(&self) -> Result<Option<sql_orm_core::SqlValue>, OrmError> {
            Ok(None)
        }

        fn sync_persisted(&mut self, persisted: Self) {
            *self = persisted;
        }
    }

    /// Context wrapping a disconnected `DbSet<TestEntity>`.
    /// Its connection accessors panic if called; the tracking registry comes from the set.
    struct DummyContext {
        entities: DbSet<TestEntity>,
    }

    impl DbContext for DummyContext {
        fn from_shared_connection(_connection: crate::SharedConnection) -> Self {
            unreachable!("DummyContext is only used in disconnected unit tests")
        }

        fn shared_connection(&self) -> crate::SharedConnection {
            panic!("DummyContext is only used in disconnected unit tests")
        }

        fn tracking_registry(&self) -> crate::TrackingRegistryHandle {
            self.entities.tracking_registry()
        }
    }

    impl DbContextEntitySet<TestEntity> for DummyContext {
        fn db_set(&self) -> &DbSet<TestEntity> {
            &self.entities
        }
    }

    // `ActiveRecord::query` builds the same SELECT as querying the entity directly.
    #[test]
    fn active_record_query_delegates_to_typed_dbset() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };

        let query = TestEntity::query(&context);

        assert_eq!(
            query.into_select_query(),
            SelectQuery::from_entity::<TestEntity>()
        );
    }

    // Compile-time check that the blanket impl covers a plain `Entity`.
    #[test]
    fn active_record_trait_is_available_for_entities() {
        fn require_active_record<E: ActiveRecord>() {}

        require_active_record::<TestEntity>();
    }

    // An unchanged `Tracked` saves successfully against a disconnected set, so it never
    // reaches the database. It stays `Unchanged` with original == current.
    #[tokio::test]
    async fn tracked_save_unchanged_is_noop_without_dereferencing_to_active_record() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });

        tracked.save(&context).await.unwrap();

        assert_eq!(tracked.state(), crate::EntityState::Unchanged);
        assert_eq!(tracked.original(), tracked.current());
    }

    // Same as above, but with a registry attachment. The single registry entry must
    // remain and stay `Unchanged`.
    #[tokio::test]
    async fn tracked_save_unchanged_registered_entry_remains_tracked() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let registry = context.entities.tracking_registry();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        tracked
            .attach_registry_loaded(registry.clone(), sql_orm_core::SqlValue::I64(7))
            .unwrap();

        tracked.save(&context).await.unwrap();

        assert_eq!(tracked.state(), crate::EntityState::Unchanged);
        assert_eq!(registry.entry_count(), 1);
        assert_eq!(
            registry.registrations()[0].state,
            crate::EntityState::Unchanged
        );
    }

    // Saving a `Tracked` that was marked for removal fails with a fixed message.
    // The message differs from the disconnected-set error, so it is produced before any
    // database call.
    #[tokio::test]
    async fn tracked_save_deleted_returns_stable_error_before_active_record() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        context.entities.remove_tracked(&mut tracked);

        let error = tracked.save(&context).await.unwrap_err();

        assert_eq!(
            error.message(),
            "tracked deleted entities cannot be saved; detach them or persist deletion"
        );
    }

    // After the failed save above, the entity and its registry entry both remain `Deleted`.
    #[tokio::test]
    async fn tracked_save_deleted_registered_entry_keeps_pending_delete_after_error() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let registry = context.entities.tracking_registry();
        let mut tracked = Tracked::from_loaded(TestEntity {
            id: 7,
            name: "Tracked".to_string(),
        });
        tracked
            .attach_registry_loaded(registry.clone(), sql_orm_core::SqlValue::I64(7))
            .unwrap();
        context.entities.remove_tracked(&mut tracked);

        let error = tracked.save(&context).await.unwrap_err();

        assert_eq!(
            error.message(),
            "tracked deleted entities cannot be saved; detach them or persist deletion"
        );
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 1);
        assert_eq!(
            registry.registrations()[0].state,
            crate::EntityState::Deleted
        );
    }

    // Deleting an entity that was only added locally succeeds against the disconnected set
    // and reports `false`. The entity is marked `Deleted` and its registry entry is dropped.
    #[tokio::test]
    async fn tracked_delete_added_cancels_local_insert_without_active_record() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let registry = context.entities.tracking_registry();
        let mut tracked = context.entities.add_tracked(TestEntity {
            id: 0,
            name: "Pending".to_string(),
        });

        let deleted = tracked.delete(&context).await.unwrap();

        assert!(!deleted);
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 0);
    }

    // A second delete of an already-deleted tracked entity succeeds again with `false`,
    // and the state does not change.
    #[tokio::test]
    async fn tracked_delete_deleted_entry_is_idempotent_without_active_record() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let registry = context.entities.tracking_registry();
        let mut tracked = context.entities.add_tracked(TestEntity {
            id: 0,
            name: "Pending".to_string(),
        });
        tracked.delete(&context).await.unwrap();

        let deleted = tracked.delete(&context).await.unwrap();

        assert!(!deleted);
        assert_eq!(tracked.state(), crate::EntityState::Deleted);
        assert_eq!(registry.entry_count(), 0);
    }

    // `ActiveRecord::find` surfaces the set's query-side disconnected error unchanged.
    #[test]
    fn active_record_find_reuses_dbset_error_contract() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };

        let runtime = tokio::runtime::Runtime::new().expect("tokio runtime");
        let error = match runtime.block_on(TestEntity::find(&context, 1_i64)) {
            Ok(value) => panic!("expected disconnected ActiveRecord::find to fail, got {value:?}"),
            Err(error) => error,
        };

        assert_eq!(
            error,
            OrmError::new("DbSetQuery requires an initialized shared connection")
        );
    }

    // `ActiveRecord::delete` surfaces the set's disconnected error unchanged.
    #[test]
    fn active_record_delete_reuses_dbset_error_contract() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let entity = TestEntity {
            id: 7,
            name: "Ana".to_string(),
        };

        let runtime = tokio::runtime::Runtime::new().expect("tokio runtime");
        let error = match runtime.block_on(entity.delete(&context)) {
            Ok(value) => {
                panic!("expected disconnected ActiveRecord::delete to fail, got {value:?}")
            }
            Err(error) => error,
        };

        assert_eq!(
            error,
            OrmError::new("DbSet requires an initialized shared connection")
        );
    }

    // `ActiveRecord::save` (update path, per `persist_mode`) surfaces the set's
    // disconnected error unchanged.
    #[test]
    fn active_record_save_reuses_dbset_error_contract() {
        let context = DummyContext {
            entities: DbSet::<TestEntity>::disconnected(),
        };
        let mut entity = TestEntity {
            id: 7,
            name: "Ana".to_string(),
        };

        let runtime = tokio::runtime::Runtime::new().expect("tokio runtime");
        let error = match runtime.block_on(entity.save(&context)) {
            Ok(()) => panic!("expected disconnected ActiveRecord::save to fail"),
            Err(error) => error,
        };

        assert_eq!(
            error,
            OrmError::new("DbSet requires an initialized shared connection")
        );
    }
}