Skip to main content

sqlmodel_session/
unit_of_work.rs

//! Unit of Work pattern implementation for SQLModel Session.
//!
//! The Unit of Work pattern tracks all changes made during a session and
//! flushes them atomically in the correct dependency order.
//!
//! # Overview
//!
//! The Unit of Work:
//! - Tracks new objects to INSERT
//! - Tracks modified (dirty) objects to UPDATE
//! - Tracks deleted objects to DELETE
//! - Determines flush order based on foreign key dependencies
//! - Detects dependency cycles and reports errors
//! - Executes all changes in a single atomic transaction
//!
//! # Example
//!
//! ```ignore
//! let mut uow = UnitOfWork::new();
//!
//! // Register models (extracts FK dependencies)
//! uow.register_model::<Team>();
//! uow.register_model::<Hero>();
//!
//! // Track changes (keys are passed by value)
//! uow.track_new(&team, team_key);
//! uow.track_new(&hero, hero_key);
//! uow.track_dirty(&existing_hero, existing_hero_key, vec!["name"]);
//! uow.track_deleted(&old_team, old_team_key);
//!
//! // Compute flush plan (checks for cycles)
//! let plan = uow.compute_flush_plan()?;
//!
//! // Execute (in a transaction)
//! plan.execute(&cx, &conn).await?;
//! ```
37
38use crate::ObjectKey;
39use crate::change_tracker::ChangeTracker;
40use crate::flush::{FlushOrderer, FlushPlan, PendingOp};
41use serde::Serialize;
42use sqlmodel_core::{Error, Model, Value};
43use std::collections::{HashMap, HashSet};
44
45/// Tracks and manages all pending changes in a session.
46///
47/// The Unit of Work is responsible for:
48/// - Maintaining the set of new, dirty, and deleted objects
49/// - Computing the correct flush order based on FK dependencies
50/// - Detecting dependency cycles before flush
51#[derive(Default)]
52pub struct UnitOfWork {
53    /// Objects to be inserted (new).
54    new_objects: Vec<TrackedInsert>,
55
56    /// Objects that have been modified (dirty).
57    dirty_objects: Vec<TrackedUpdate>,
58
59    /// Objects to be deleted.
60    deleted_objects: Vec<TrackedDelete>,
61
62    /// Change tracker for dirty detection.
63    change_tracker: ChangeTracker,
64
65    /// Flush orderer for dependency-based ordering.
66    orderer: FlushOrderer,
67
68    /// Tables we've seen (for cycle detection).
69    tables: HashSet<&'static str>,
70
71    /// Table -> tables it depends on.
72    table_dependencies: HashMap<&'static str, Vec<&'static str>>,
73}
74
75/// A tracked object pending insertion.
76struct TrackedInsert {
77    key: ObjectKey,
78    table: &'static str,
79    columns: Vec<&'static str>,
80    values: Vec<Value>,
81}
82
83/// A tracked object pending update.
84struct TrackedUpdate {
85    key: ObjectKey,
86    table: &'static str,
87    pk_columns: Vec<&'static str>,
88    pk_values: Vec<Value>,
89    set_columns: Vec<&'static str>,
90    set_values: Vec<Value>,
91}
92
93/// A tracked object pending deletion.
94struct TrackedDelete {
95    key: ObjectKey,
96    table: &'static str,
97    pk_columns: Vec<&'static str>,
98    pk_values: Vec<Value>,
99}
100
101/// Error type for Unit of Work operations.
102#[derive(Debug, Clone)]
103pub enum UowError {
104    /// A dependency cycle was detected between tables.
105    CycleDetected {
106        /// Tables involved in the cycle.
107        tables: Vec<&'static str>,
108    },
109    /// An object was already tracked.
110    AlreadyTracked {
111        /// The object key.
112        key: ObjectKey,
113        /// The tracking state (new, dirty, deleted).
114        state: &'static str,
115    },
116}
117
118impl std::fmt::Display for UowError {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        match self {
121            UowError::CycleDetected { tables } => {
122                write!(f, "Dependency cycle detected: {}", tables.join(" -> "))
123            }
124            UowError::AlreadyTracked { key, state } => {
125                write!(f, "Object {:?} already tracked as {}", key, state)
126            }
127        }
128    }
129}
130
131impl std::error::Error for UowError {}
132
133impl From<UowError> for Error {
134    fn from(e: UowError) -> Self {
135        Error::Custom(e.to_string())
136    }
137}
138
139impl UnitOfWork {
140    /// Create a new empty Unit of Work.
141    #[must_use]
142    pub fn new() -> Self {
143        Self::default()
144    }
145
146    /// Register a model type for dependency tracking.
147    ///
148    /// This extracts foreign key relationships from the model's metadata
149    /// and registers them for flush ordering.
150    pub fn register_model<T: Model>(&mut self) {
151        self.orderer.register_model::<T>();
152
153        let table = T::TABLE_NAME;
154        self.tables.insert(table);
155
156        // Extract FK dependencies
157        let deps: Vec<&'static str> = T::fields()
158            .iter()
159            .filter_map(|f| f.foreign_key)
160            .filter_map(|fk| fk.split('.').next())
161            .collect();
162
163        self.table_dependencies.insert(table, deps);
164    }
165
166    /// Track a new object for insertion.
167    ///
168    /// The object will be INSERTed during flush.
169    pub fn track_new<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
170        let row = model.to_row();
171        let columns: Vec<&'static str> = row.iter().map(|(col, _)| *col).collect();
172        let values: Vec<Value> = row.into_iter().map(|(_, val)| val).collect();
173
174        self.new_objects.push(TrackedInsert {
175            key,
176            table: T::TABLE_NAME,
177            columns,
178            values,
179        });
180    }
181
182    /// Track a dirty object for update.
183    ///
184    /// The object will be UPDATEd during flush (only changed columns).
185    pub fn track_dirty<T: Model + Serialize>(
186        &mut self,
187        model: &T,
188        key: ObjectKey,
189        changed_columns: Vec<&'static str>,
190    ) {
191        if changed_columns.is_empty() {
192            return;
193        }
194
195        let row = model.to_row();
196        let row_map: HashMap<&str, Value> = row.into_iter().collect();
197
198        let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
199        let pk_values = model.primary_key_value();
200
201        let set_columns = changed_columns;
202        let set_values: Vec<Value> = set_columns
203            .iter()
204            .filter_map(|col| row_map.get(*col).cloned())
205            .collect();
206
207        self.dirty_objects.push(TrackedUpdate {
208            key,
209            table: T::TABLE_NAME,
210            pk_columns,
211            pk_values,
212            set_columns,
213            set_values,
214        });
215    }
216
217    /// Track a dirty object for update (auto-detect changed fields).
218    ///
219    /// Uses the change tracker to determine which fields changed.
220    pub fn track_dirty_auto<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
221        let changed = self.change_tracker.changed_fields(&key, model);
222        if !changed.is_empty() {
223            self.track_dirty(model, key, changed);
224        }
225    }
226
227    /// Track an object for deletion.
228    ///
229    /// The object will be DELETEd during flush.
230    pub fn track_deleted<T: Model>(&mut self, model: &T, key: ObjectKey) {
231        let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
232        let pk_values = model.primary_key_value();
233
234        self.deleted_objects.push(TrackedDelete {
235            key,
236            table: T::TABLE_NAME,
237            pk_columns,
238            pk_values,
239        });
240    }
241
242    /// Take a snapshot of an object for later dirty detection.
243    pub fn snapshot<T: Model + Serialize>(&mut self, key: ObjectKey, model: &T) {
244        self.change_tracker.snapshot(key, model);
245    }
246
247    /// Check if an object is dirty (has changed since snapshot).
248    pub fn is_dirty<T: Model + Serialize>(&self, key: &ObjectKey, model: &T) -> bool {
249        self.change_tracker.is_dirty(key, model)
250    }
251
252    /// Get the changed fields for an object.
253    pub fn changed_fields<T: Model + Serialize>(
254        &self,
255        key: &ObjectKey,
256        model: &T,
257    ) -> Vec<&'static str> {
258        self.change_tracker.changed_fields(key, model)
259    }
260
261    /// Check for dependency cycles in the registered tables.
262    ///
263    /// Returns `Err(UowError::CycleDetected)` if a cycle is found.
264    pub fn check_cycles(&self) -> Result<(), UowError> {
265        // Use DFS to detect cycles
266        let mut visited = HashSet::new();
267        let mut rec_stack = HashSet::new();
268        let mut cycle_path = Vec::new();
269
270        for table in &self.tables {
271            if !visited.contains(table)
272                && self.detect_cycle_dfs(table, &mut visited, &mut rec_stack, &mut cycle_path)
273            {
274                return Err(UowError::CycleDetected { tables: cycle_path });
275            }
276        }
277
278        Ok(())
279    }
280
281    /// DFS helper for cycle detection.
282    fn detect_cycle_dfs(
283        &self,
284        table: &'static str,
285        visited: &mut HashSet<&'static str>,
286        rec_stack: &mut HashSet<&'static str>,
287        path: &mut Vec<&'static str>,
288    ) -> bool {
289        visited.insert(table);
290        rec_stack.insert(table);
291        path.push(table);
292
293        if let Some(deps) = self.table_dependencies.get(table) {
294            for dep in deps {
295                // Only check tables we know about
296                if !self.tables.contains(dep) {
297                    continue;
298                }
299
300                if !visited.contains(dep) {
301                    if self.detect_cycle_dfs(dep, visited, rec_stack, path) {
302                        return true;
303                    }
304                } else if rec_stack.contains(dep) {
305                    // Found cycle - add the start of cycle to complete it
306                    path.push(dep);
307                    return true;
308                }
309            }
310        }
311
312        rec_stack.remove(table);
313        path.pop();
314        false
315    }
316
317    /// Compute the flush plan.
318    ///
319    /// This checks for cycles and orders operations by dependencies.
320    ///
321    /// # Errors
322    ///
323    /// Returns `Err` if a dependency cycle is detected.
324    pub fn compute_flush_plan(&self) -> Result<FlushPlan, UowError> {
325        // Check for cycles first
326        self.check_cycles()?;
327
328        // Build pending ops
329        let mut ops = Vec::new();
330
331        // Add inserts
332        for insert in &self.new_objects {
333            ops.push(PendingOp::Insert {
334                key: insert.key,
335                table: insert.table,
336                columns: insert.columns.clone(),
337                values: insert.values.clone(),
338            });
339        }
340
341        // Add updates
342        for update in &self.dirty_objects {
343            ops.push(PendingOp::Update {
344                key: update.key,
345                table: update.table,
346                pk_columns: update.pk_columns.clone(),
347                pk_values: update.pk_values.clone(),
348                set_columns: update.set_columns.clone(),
349                set_values: update.set_values.clone(),
350            });
351        }
352
353        // Add deletes
354        for delete in &self.deleted_objects {
355            ops.push(PendingOp::Delete {
356                key: delete.key,
357                table: delete.table,
358                pk_columns: delete.pk_columns.clone(),
359                pk_values: delete.pk_values.clone(),
360            });
361        }
362
363        // Order by dependencies
364        Ok(self.orderer.order(ops))
365    }
366
367    /// Clear all tracked changes.
368    ///
369    /// Call this after a successful commit.
370    pub fn clear(&mut self) {
371        self.new_objects.clear();
372        self.dirty_objects.clear();
373        self.deleted_objects.clear();
374        self.change_tracker.clear_all();
375    }
376
377    /// Check if there are any pending changes.
378    #[must_use]
379    pub fn has_changes(&self) -> bool {
380        !self.new_objects.is_empty()
381            || !self.dirty_objects.is_empty()
382            || !self.deleted_objects.is_empty()
383    }
384
385    /// Get the count of pending operations.
386    #[must_use]
387    pub fn pending_count(&self) -> PendingCounts {
388        PendingCounts {
389            new: self.new_objects.len(),
390            dirty: self.dirty_objects.len(),
391            deleted: self.deleted_objects.len(),
392        }
393    }
394
395    /// Get a reference to the change tracker.
396    #[must_use]
397    pub fn change_tracker(&self) -> &ChangeTracker {
398        &self.change_tracker
399    }
400
401    /// Get a mutable reference to the change tracker.
402    pub fn change_tracker_mut(&mut self) -> &mut ChangeTracker {
403        &mut self.change_tracker
404    }
405}
406
/// Number of queued operations, broken down by kind.
#[derive(Debug, Clone, Copy, Default)]
pub struct PendingCounts {
    /// Number of objects awaiting INSERT.
    pub new: usize,
    /// Number of objects awaiting UPDATE.
    pub dirty: usize,
    /// Number of objects awaiting DELETE.
    pub deleted: usize,
}

impl PendingCounts {
    /// Sum of the three counters.
    #[must_use]
    pub fn total(&self) -> usize {
        [self.new, self.dirty, self.deleted].iter().sum()
    }

    /// `true` when all three counters are zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        [self.new, self.dirty, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use serde::{Deserialize, Serialize};
436    use sqlmodel_core::{FieldInfo, Row, SqlType};
437
438    #[derive(Debug, Clone, Serialize, Deserialize)]
439    struct Team {
440        id: Option<i64>,
441        name: String,
442    }
443
444    impl Model for Team {
445        const TABLE_NAME: &'static str = "teams";
446        const PRIMARY_KEY: &'static [&'static str] = &["id"];
447
448        fn fields() -> &'static [FieldInfo] {
449            static FIELDS: &[FieldInfo] = &[
450                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
451                FieldInfo::new("name", "name", SqlType::Text),
452            ];
453            FIELDS
454        }
455
456        fn to_row(&self) -> Vec<(&'static str, Value)> {
457            vec![
458                ("id", self.id.map_or(Value::Null, Value::BigInt)),
459                ("name", Value::Text(self.name.clone())),
460            ]
461        }
462
463        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
464            Ok(Self {
465                id: None,
466                name: String::new(),
467            })
468        }
469
470        fn primary_key_value(&self) -> Vec<Value> {
471            vec![self.id.map_or(Value::Null, Value::BigInt)]
472        }
473
474        fn is_new(&self) -> bool {
475            self.id.is_none()
476        }
477    }
478
479    #[derive(Debug, Clone, Serialize, Deserialize)]
480    struct Hero {
481        id: Option<i64>,
482        name: String,
483        team_id: Option<i64>,
484    }
485
486    impl Model for Hero {
487        const TABLE_NAME: &'static str = "heroes";
488        const PRIMARY_KEY: &'static [&'static str] = &["id"];
489
490        fn fields() -> &'static [FieldInfo] {
491            static FIELDS: &[FieldInfo] = &[
492                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
493                FieldInfo::new("name", "name", SqlType::Text),
494                FieldInfo::new("team_id", "team_id", SqlType::BigInt)
495                    .nullable(true)
496                    .foreign_key("teams.id"),
497            ];
498            FIELDS
499        }
500
501        fn to_row(&self) -> Vec<(&'static str, Value)> {
502            vec![
503                ("id", self.id.map_or(Value::Null, Value::BigInt)),
504                ("name", Value::Text(self.name.clone())),
505                ("team_id", self.team_id.map_or(Value::Null, Value::BigInt)),
506            ]
507        }
508
509        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
510            Ok(Self {
511                id: None,
512                name: String::new(),
513                team_id: None,
514            })
515        }
516
517        fn primary_key_value(&self) -> Vec<Value> {
518            vec![self.id.map_or(Value::Null, Value::BigInt)]
519        }
520
521        fn is_new(&self) -> bool {
522            self.id.is_none()
523        }
524    }
525
526    fn make_key<T: Model + 'static>(pk: i64) -> ObjectKey {
527        ObjectKey::from_pk::<T>(&[Value::BigInt(pk)])
528    }
529
530    #[test]
531    fn test_track_new_object() {
532        let mut uow = UnitOfWork::new();
533
534        let team = Team {
535            id: Some(1),
536            name: "Avengers".to_string(),
537        };
538        let key = make_key::<Team>(1);
539
540        uow.track_new(&team, key);
541
542        assert!(uow.has_changes());
543        assert_eq!(uow.pending_count().new, 1);
544        assert_eq!(uow.pending_count().dirty, 0);
545        assert_eq!(uow.pending_count().deleted, 0);
546    }
547
548    #[test]
549    fn test_track_dirty_object() {
550        let mut uow = UnitOfWork::new();
551
552        let hero = Hero {
553            id: Some(1),
554            name: "Spider-Man".to_string(),
555            team_id: Some(1),
556        };
557        let key = make_key::<Hero>(1);
558
559        uow.track_dirty(&hero, key, vec!["name"]);
560
561        assert!(uow.has_changes());
562        assert_eq!(uow.pending_count().dirty, 1);
563    }
564
565    #[test]
566    fn test_track_deleted_object() {
567        let mut uow = UnitOfWork::new();
568
569        let team = Team {
570            id: Some(1),
571            name: "Avengers".to_string(),
572        };
573        let key = make_key::<Team>(1);
574
575        uow.track_deleted(&team, key);
576
577        assert!(uow.has_changes());
578        assert_eq!(uow.pending_count().deleted, 1);
579    }
580
581    #[test]
582    fn test_compute_flush_plan_orders_correctly() {
583        let mut uow = UnitOfWork::new();
584        uow.register_model::<Team>();
585        uow.register_model::<Hero>();
586
587        // Add hero first (has FK to team), then team
588        let hero = Hero {
589            id: Some(1),
590            name: "Spider-Man".to_string(),
591            team_id: Some(1),
592        };
593        let team = Team {
594            id: Some(1),
595            name: "Avengers".to_string(),
596        };
597
598        uow.track_new(&hero, make_key::<Hero>(1));
599        uow.track_new(&team, make_key::<Team>(1));
600
601        let plan = uow.compute_flush_plan().unwrap();
602
603        // Team should be inserted first (no deps)
604        assert_eq!(plan.inserts[0].table(), "teams");
605        assert_eq!(plan.inserts[1].table(), "heroes");
606    }
607
608    #[test]
609    fn test_clear_removes_all_tracked() {
610        let mut uow = UnitOfWork::new();
611
612        let team = Team {
613            id: Some(1),
614            name: "Avengers".to_string(),
615        };
616        uow.track_new(&team, make_key::<Team>(1));
617        uow.track_deleted(&team, make_key::<Team>(2));
618
619        assert!(uow.has_changes());
620
621        uow.clear();
622
623        assert!(!uow.has_changes());
624        assert!(uow.pending_count().is_empty());
625    }
626
627    #[test]
628    fn test_snapshot_and_dirty_detection() {
629        let mut uow = UnitOfWork::new();
630
631        let hero = Hero {
632            id: Some(1),
633            name: "Spider-Man".to_string(),
634            team_id: Some(1),
635        };
636        let key = make_key::<Hero>(1);
637
638        // Take snapshot
639        uow.snapshot(key, &hero);
640
641        // Not dirty yet
642        assert!(!uow.is_dirty(&key, &hero));
643
644        // Modify
645        let modified = Hero {
646            id: Some(1),
647            name: "Peter Parker".to_string(),
648            team_id: Some(1),
649        };
650
651        // Now dirty
652        assert!(uow.is_dirty(&key, &modified));
653
654        // Check which fields changed
655        let changed = uow.changed_fields(&key, &modified);
656        assert_eq!(changed, vec!["name"]);
657    }
658
659    #[test]
660    fn test_track_dirty_auto() {
661        let mut uow = UnitOfWork::new();
662
663        let hero = Hero {
664            id: Some(1),
665            name: "Spider-Man".to_string(),
666            team_id: Some(1),
667        };
668        let key = make_key::<Hero>(1);
669
670        // Snapshot original
671        uow.snapshot(key, &hero);
672
673        // Modify
674        let modified = Hero {
675            id: Some(1),
676            name: "Peter Parker".to_string(),
677            team_id: Some(2),
678        };
679
680        // Auto-track dirty
681        uow.track_dirty_auto(&modified, key);
682
683        assert_eq!(uow.pending_count().dirty, 1);
684    }
685
686    #[test]
687    fn test_no_cycle_in_normal_hierarchy() {
688        let mut uow = UnitOfWork::new();
689        uow.register_model::<Team>();
690        uow.register_model::<Hero>();
691
692        // Hero -> Team is a valid hierarchy (no cycle)
693        assert!(uow.check_cycles().is_ok());
694    }
695
696    #[test]
697    fn test_pending_counts() {
698        let counts = PendingCounts {
699            new: 3,
700            dirty: 2,
701            deleted: 1,
702        };
703
704        assert_eq!(counts.total(), 6);
705        assert!(!counts.is_empty());
706
707        let empty = PendingCounts::default();
708        assert!(empty.is_empty());
709        assert_eq!(empty.total(), 0);
710    }
711
712    #[test]
713    fn test_empty_dirty_not_tracked() {
714        let mut uow = UnitOfWork::new();
715
716        let hero = Hero {
717            id: Some(1),
718            name: "Spider-Man".to_string(),
719            team_id: Some(1),
720        };
721        let key = make_key::<Hero>(1);
722
723        // Empty changed columns - should not track
724        uow.track_dirty(&hero, key, vec![]);
725
726        assert!(!uow.has_changes());
727    }
728}