1use crate::ObjectKey;
39use crate::change_tracker::ChangeTracker;
40use crate::flush::{FlushOrderer, FlushPlan, PendingOp};
41use serde::Serialize;
42use sqlmodel_core::{Error, Model, Value};
43use std::collections::{HashMap, HashSet};
44
45#[derive(Default)]
52pub struct UnitOfWork {
53 new_objects: Vec<TrackedInsert>,
55
56 dirty_objects: Vec<TrackedUpdate>,
58
59 deleted_objects: Vec<TrackedDelete>,
61
62 change_tracker: ChangeTracker,
64
65 orderer: FlushOrderer,
67
68 tables: HashSet<&'static str>,
70
71 table_dependencies: HashMap<&'static str, Vec<&'static str>>,
73}
74
75struct TrackedInsert {
77 key: ObjectKey,
78 table: &'static str,
79 columns: Vec<&'static str>,
80 values: Vec<Value>,
81}
82
83struct TrackedUpdate {
85 key: ObjectKey,
86 table: &'static str,
87 pk_columns: Vec<&'static str>,
88 pk_values: Vec<Value>,
89 set_columns: Vec<&'static str>,
90 set_values: Vec<Value>,
91}
92
93struct TrackedDelete {
95 key: ObjectKey,
96 table: &'static str,
97 pk_columns: Vec<&'static str>,
98 pk_values: Vec<Value>,
99}
100
101#[derive(Debug, Clone)]
103pub enum UowError {
104 CycleDetected {
106 tables: Vec<&'static str>,
108 },
109 AlreadyTracked {
111 key: ObjectKey,
113 state: &'static str,
115 },
116}
117
118impl std::fmt::Display for UowError {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 match self {
121 UowError::CycleDetected { tables } => {
122 write!(f, "Dependency cycle detected: {}", tables.join(" -> "))
123 }
124 UowError::AlreadyTracked { key, state } => {
125 write!(f, "Object {:?} already tracked as {}", key, state)
126 }
127 }
128 }
129}
130
131impl std::error::Error for UowError {}
132
133impl From<UowError> for Error {
134 fn from(e: UowError) -> Self {
135 Error::Custom(e.to_string())
136 }
137}
138
139impl UnitOfWork {
140 #[must_use]
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn register_model<T: Model>(&mut self) {
151 self.orderer.register_model::<T>();
152
153 let table = T::TABLE_NAME;
154 self.tables.insert(table);
155
156 let deps: Vec<&'static str> = T::fields()
158 .iter()
159 .filter_map(|f| f.foreign_key)
160 .filter_map(|fk| fk.split('.').next())
161 .collect();
162
163 self.table_dependencies.insert(table, deps);
164 }
165
166 pub fn track_new<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
170 let row = model.to_row();
171 let columns: Vec<&'static str> = row.iter().map(|(col, _)| *col).collect();
172 let values: Vec<Value> = row.into_iter().map(|(_, val)| val).collect();
173
174 self.new_objects.push(TrackedInsert {
175 key,
176 table: T::TABLE_NAME,
177 columns,
178 values,
179 });
180 }
181
182 pub fn track_dirty<T: Model + Serialize>(
186 &mut self,
187 model: &T,
188 key: ObjectKey,
189 changed_columns: Vec<&'static str>,
190 ) {
191 if changed_columns.is_empty() {
192 return;
193 }
194
195 let row = model.to_row();
196 let row_map: HashMap<&str, Value> = row.into_iter().collect();
197
198 let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
199 let pk_values = model.primary_key_value();
200
201 let set_columns = changed_columns;
202 let set_values: Vec<Value> = set_columns
203 .iter()
204 .filter_map(|col| row_map.get(*col).cloned())
205 .collect();
206
207 self.dirty_objects.push(TrackedUpdate {
208 key,
209 table: T::TABLE_NAME,
210 pk_columns,
211 pk_values,
212 set_columns,
213 set_values,
214 });
215 }
216
217 pub fn track_dirty_auto<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
221 let changed = self.change_tracker.changed_fields(&key, model);
222 if !changed.is_empty() {
223 self.track_dirty(model, key, changed);
224 }
225 }
226
227 pub fn track_deleted<T: Model>(&mut self, model: &T, key: ObjectKey) {
231 let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
232 let pk_values = model.primary_key_value();
233
234 self.deleted_objects.push(TrackedDelete {
235 key,
236 table: T::TABLE_NAME,
237 pk_columns,
238 pk_values,
239 });
240 }
241
242 pub fn snapshot<T: Model + Serialize>(&mut self, key: ObjectKey, model: &T) {
244 self.change_tracker.snapshot(key, model);
245 }
246
247 pub fn is_dirty<T: Model + Serialize>(&self, key: &ObjectKey, model: &T) -> bool {
249 self.change_tracker.is_dirty(key, model)
250 }
251
252 pub fn changed_fields<T: Model + Serialize>(
254 &self,
255 key: &ObjectKey,
256 model: &T,
257 ) -> Vec<&'static str> {
258 self.change_tracker.changed_fields(key, model)
259 }
260
261 pub fn check_cycles(&self) -> Result<(), UowError> {
265 let mut visited = HashSet::new();
267 let mut rec_stack = HashSet::new();
268 let mut cycle_path = Vec::new();
269
270 for table in &self.tables {
271 if !visited.contains(table)
272 && self.detect_cycle_dfs(table, &mut visited, &mut rec_stack, &mut cycle_path)
273 {
274 return Err(UowError::CycleDetected { tables: cycle_path });
275 }
276 }
277
278 Ok(())
279 }
280
281 fn detect_cycle_dfs(
283 &self,
284 table: &'static str,
285 visited: &mut HashSet<&'static str>,
286 rec_stack: &mut HashSet<&'static str>,
287 path: &mut Vec<&'static str>,
288 ) -> bool {
289 visited.insert(table);
290 rec_stack.insert(table);
291 path.push(table);
292
293 if let Some(deps) = self.table_dependencies.get(table) {
294 for dep in deps {
295 if !self.tables.contains(dep) {
297 continue;
298 }
299
300 if !visited.contains(dep) {
301 if self.detect_cycle_dfs(dep, visited, rec_stack, path) {
302 return true;
303 }
304 } else if rec_stack.contains(dep) {
305 path.push(dep);
307 return true;
308 }
309 }
310 }
311
312 rec_stack.remove(table);
313 path.pop();
314 false
315 }
316
317 pub fn compute_flush_plan(&self) -> Result<FlushPlan, UowError> {
325 self.check_cycles()?;
327
328 let mut ops = Vec::new();
330
331 for insert in &self.new_objects {
333 ops.push(PendingOp::Insert {
334 key: insert.key,
335 table: insert.table,
336 columns: insert.columns.clone(),
337 values: insert.values.clone(),
338 });
339 }
340
341 for update in &self.dirty_objects {
343 ops.push(PendingOp::Update {
344 key: update.key,
345 table: update.table,
346 pk_columns: update.pk_columns.clone(),
347 pk_values: update.pk_values.clone(),
348 set_columns: update.set_columns.clone(),
349 set_values: update.set_values.clone(),
350 });
351 }
352
353 for delete in &self.deleted_objects {
355 ops.push(PendingOp::Delete {
356 key: delete.key,
357 table: delete.table,
358 pk_columns: delete.pk_columns.clone(),
359 pk_values: delete.pk_values.clone(),
360 });
361 }
362
363 Ok(self.orderer.order(ops))
365 }
366
367 pub fn clear(&mut self) {
371 self.new_objects.clear();
372 self.dirty_objects.clear();
373 self.deleted_objects.clear();
374 self.change_tracker.clear_all();
375 }
376
377 #[must_use]
379 pub fn has_changes(&self) -> bool {
380 !self.new_objects.is_empty()
381 || !self.dirty_objects.is_empty()
382 || !self.deleted_objects.is_empty()
383 }
384
385 #[must_use]
387 pub fn pending_count(&self) -> PendingCounts {
388 PendingCounts {
389 new: self.new_objects.len(),
390 dirty: self.dirty_objects.len(),
391 deleted: self.deleted_objects.len(),
392 }
393 }
394
395 #[must_use]
397 pub fn change_tracker(&self) -> &ChangeTracker {
398 &self.change_tracker
399 }
400
401 pub fn change_tracker_mut(&mut self) -> &mut ChangeTracker {
403 &mut self.change_tracker
404 }
405}
406
/// Number of pending operations of each kind, as returned by
/// `UnitOfWork::pending_count`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PendingCounts {
    /// Pending inserts.
    pub new: usize,
    /// Pending updates.
    pub dirty: usize,
    /// Pending deletes.
    pub deleted: usize,
}

impl PendingCounts {
    /// Total number of pending operations across all kinds.
    #[must_use]
    pub fn total(&self) -> usize {
        [self.new, self.dirty, self.deleted].iter().sum()
    }

    /// Returns `true` when no operation of any kind is pending.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        [self.new, self.dirty, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use sqlmodel_core::{FieldInfo, Row, SqlType};

    /// Parent model: `teams` has no foreign keys.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Team {
        id: Option<i64>,
        name: String,
    }

    impl Model for Team {
        const TABLE_NAME: &'static str = "teams";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
            ]
        }

        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: None,
                name: String::new(),
            })
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }

        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    /// Child model: `heroes.team_id` references `teams.id`, so `heroes`
    /// depends on `teams`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Hero {
        id: Option<i64>,
        name: String,
        team_id: Option<i64>,
    }

    impl Model for Hero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
                FieldInfo::new("team_id", "team_id", SqlType::BigInt)
                    .nullable(true)
                    .foreign_key("teams.id"),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
                ("team_id", self.team_id.map_or(Value::Null, Value::BigInt)),
            ]
        }

        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: None,
                name: String::new(),
                team_id: None,
            })
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }

        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    fn make_key<T: Model + 'static>(pk: i64) -> ObjectKey {
        ObjectKey::from_pk::<T>(&[Value::BigInt(pk)])
    }

    /// Adds `table` and its dependency list directly (bypassing
    /// `register_model`) so arbitrary graphs can be tested without extra models.
    fn add_table(uow: &mut UnitOfWork, table: &'static str, deps: Vec<&'static str>) {
        uow.tables.insert(table);
        uow.table_dependencies.insert(table, deps);
    }

    #[test]
    fn test_track_new_object() {
        let mut uow = UnitOfWork::new();

        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        let key = make_key::<Team>(1);

        uow.track_new(&team, key);

        assert!(uow.has_changes());
        let counts = uow.pending_count();
        assert_eq!(counts.new, 1);
        assert_eq!(counts.dirty, 0);
        assert_eq!(counts.deleted, 0);

        let insert = &uow.new_objects[0];
        assert_eq!(insert.table, "teams");
        assert_eq!(insert.columns, vec!["id", "name"]);
        assert_eq!(insert.values.len(), insert.columns.len());
    }

    #[test]
    fn test_track_dirty_object() {
        let mut uow = UnitOfWork::new();

        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);

        uow.track_dirty(&hero, key, vec!["name"]);

        assert!(uow.has_changes());
        assert_eq!(uow.pending_count().dirty, 1);
    }

    #[test]
    fn test_track_dirty_pairs_columns_with_values() {
        let mut uow = UnitOfWork::new();

        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };

        uow.track_dirty(&hero, make_key::<Hero>(1), vec!["name"]);

        let update = &uow.dirty_objects[0];
        assert_eq!(update.table, "heroes");
        assert_eq!(update.pk_columns, vec!["id"]);
        assert_eq!(update.pk_values.len(), 1);
        assert_eq!(update.set_columns, vec!["name"]);
        assert_eq!(update.set_values.len(), update.set_columns.len());
    }

    #[test]
    fn test_track_deleted_object() {
        let mut uow = UnitOfWork::new();

        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        let key = make_key::<Team>(1);

        uow.track_deleted(&team, key);

        assert!(uow.has_changes());
        assert_eq!(uow.pending_count().deleted, 1);
    }

    #[test]
    fn test_compute_flush_plan_orders_correctly() {
        let mut uow = UnitOfWork::new();
        uow.register_model::<Team>();
        uow.register_model::<Hero>();

        // Tracked child-first on purpose: the plan must still insert the
        // parent table first.
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };

        uow.track_new(&hero, make_key::<Hero>(1));
        uow.track_new(&team, make_key::<Team>(1));

        let plan = uow.compute_flush_plan().unwrap();

        assert_eq!(plan.inserts[0].table(), "teams");
        assert_eq!(plan.inserts[1].table(), "heroes");
    }

    #[test]
    fn test_compute_flush_plan_rejects_cycle() {
        let mut uow = UnitOfWork::new();
        add_table(&mut uow, "a", vec!["b"]);
        add_table(&mut uow, "b", vec!["a"]);

        assert!(matches!(
            uow.compute_flush_plan(),
            Err(UowError::CycleDetected { .. })
        ));
    }

    #[test]
    fn test_clear_removes_all_tracked() {
        let mut uow = UnitOfWork::new();

        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        uow.track_new(&team, make_key::<Team>(1));
        uow.track_deleted(&team, make_key::<Team>(2));

        assert!(uow.has_changes());

        uow.clear();

        assert!(!uow.has_changes());
        assert!(uow.pending_count().is_empty());
    }

    #[test]
    fn test_snapshot_and_dirty_detection() {
        let mut uow = UnitOfWork::new();

        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);

        uow.snapshot(key, &hero);

        // Unchanged object compares equal to its snapshot.
        assert!(!uow.is_dirty(&key, &hero));

        let modified = Hero {
            id: Some(1),
            name: "Peter Parker".to_string(),
            team_id: Some(1),
        };

        assert!(uow.is_dirty(&key, &modified));

        let changed = uow.changed_fields(&key, &modified);
        assert_eq!(changed, vec!["name"]);
    }

    #[test]
    fn test_track_dirty_auto() {
        let mut uow = UnitOfWork::new();

        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);

        uow.snapshot(key, &hero);

        let modified = Hero {
            id: Some(1),
            name: "Peter Parker".to_string(),
            team_id: Some(2),
        };

        uow.track_dirty_auto(&modified, key);

        assert_eq!(uow.pending_count().dirty, 1);
        let update = &uow.dirty_objects[0];
        assert_eq!(update.set_values.len(), update.set_columns.len());
    }

    #[test]
    fn test_no_cycle_in_normal_hierarchy() {
        let mut uow = UnitOfWork::new();
        uow.register_model::<Team>();
        uow.register_model::<Hero>();

        assert!(uow.check_cycles().is_ok());
    }

    #[test]
    fn test_dependency_on_unregistered_table_is_ignored() {
        // `heroes` references `teams`, which is never registered.
        let mut uow = UnitOfWork::new();
        uow.register_model::<Hero>();

        assert!(uow.check_cycles().is_ok());
    }

    #[test]
    fn test_cycle_between_two_tables_is_detected() {
        let mut uow = UnitOfWork::new();
        add_table(&mut uow, "a", vec!["b"]);
        add_table(&mut uow, "b", vec!["a"]);

        match uow.check_cycles() {
            Err(UowError::CycleDetected { tables }) => {
                assert_eq!(tables.len(), 3);
                assert_eq!(tables.first(), tables.last());
            }
            other => panic!("expected CycleDetected, got {:?}", other),
        }
    }

    #[test]
    fn test_cycle_error_display() {
        let err = UowError::CycleDetected {
            tables: vec!["a", "b", "a"],
        };
        assert_eq!(err.to_string(), "Dependency cycle detected: a -> b -> a");
    }

    #[test]
    fn test_pending_counts() {
        let counts = PendingCounts {
            new: 3,
            dirty: 2,
            deleted: 1,
        };

        assert_eq!(counts.total(), 6);
        assert!(!counts.is_empty());

        let empty = PendingCounts::default();
        assert!(empty.is_empty());
        assert_eq!(empty.total(), 0);
    }

    #[test]
    fn test_empty_dirty_not_tracked() {
        let mut uow = UnitOfWork::new();

        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);

        uow.track_dirty(&hero, key, vec![]);

        assert!(!uow.has_changes());
    }
}
728}