Skip to main content

sqlmodel_session/
change_tracker.rs

1//! Change tracking and dirty detection for SQLModel Session.
2//!
3//! This module provides snapshot-based change tracking to detect when objects
4//! have been modified since they were loaded from the database.
5
6use crate::ObjectKey;
7use serde::Serialize;
8use sqlmodel_core::Model;
9use std::collections::HashMap;
10use std::time::Instant;
11
/// A capture of one object's serialized state, stamped with its creation time.
#[derive(Debug)]
pub struct ObjectSnapshot {
    /// Bytes of the serialized original state (JSON bytes).
    data: Vec<u8>,
    /// The instant at which this snapshot was created.
    taken_at: Instant,
}

impl ObjectSnapshot {
    /// Wrap already-serialized bytes in a snapshot stamped with the current instant.
    pub fn new(data: Vec<u8>) -> Self {
        let taken_at = Instant::now();
        Self { data, taken_at }
    }

    /// Borrow the serialized bytes held by this snapshot.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Return the instant recorded when this snapshot was created.
    pub fn taken_at(&self) -> Instant {
        self.taken_at
    }
}
40
41/// Tracks changes to objects in the session.
42///
43/// Uses snapshot comparison to detect when objects have been modified.
44pub struct ChangeTracker {
45    /// Original snapshots by object key.
46    snapshots: HashMap<ObjectKey, ObjectSnapshot>,
47}
48
49impl ChangeTracker {
50    /// Create a new empty change tracker.
51    pub fn new() -> Self {
52        Self {
53            snapshots: HashMap::new(),
54        }
55    }
56
57    /// Take a snapshot of an object.
58    ///
59    /// This stores the serialized state of the object for later comparison.
60    #[tracing::instrument(level = "trace", skip(self, obj))]
61    pub fn snapshot<T: Model + Serialize>(&mut self, key: ObjectKey, obj: &T) {
62        let data = match serde_json::to_vec(obj) {
63            Ok(d) => d,
64            Err(e) => {
65                tracing::warn!(
66                    model = std::any::type_name::<T>(),
67                    error = %e,
68                    "Snapshot serialization failed, storing empty snapshot"
69                );
70                Vec::new()
71            }
72        };
73        tracing::trace!(
74            model = std::any::type_name::<T>(),
75            pk_hash = key.pk_hash(),
76            snapshot_bytes = data.len(),
77            "Taking object snapshot"
78        );
79        self.snapshots.insert(key, ObjectSnapshot::new(data));
80    }
81
82    /// Take a snapshot from raw bytes.
83    pub fn snapshot_raw(&mut self, key: ObjectKey, data: Vec<u8>) {
84        self.snapshots.insert(key, ObjectSnapshot::new(data));
85    }
86
87    /// Check if an object has changed since its snapshot.
88    ///
89    /// Returns `true` if:
90    /// - The object has no snapshot (treated as dirty)
91    /// - The current state differs from the snapshot
92    #[tracing::instrument(level = "trace", skip(self, obj))]
93    pub fn is_dirty<T: Model + Serialize>(&self, key: &ObjectKey, obj: &T) -> bool {
94        let Some(snapshot) = self.snapshots.get(key) else {
95            tracing::trace!(
96                pk_hash = key.pk_hash(),
97                dirty = true,
98                "No snapshot - treating as dirty"
99            );
100            return true;
101        };
102
103        let current = match serde_json::to_vec(obj) {
104            Ok(d) => d,
105            Err(e) => {
106                tracing::warn!(
107                    model = std::any::type_name::<T>(),
108                    error = %e,
109                    "Dirty check serialization failed, treating as dirty"
110                );
111                return true;
112            }
113        };
114        let dirty = current != snapshot.data;
115        tracing::trace!(pk_hash = key.pk_hash(), dirty = dirty, "Dirty check result");
116        dirty
117    }
118
119    /// Check if raw bytes match the snapshot.
120    pub fn is_dirty_raw(&self, key: &ObjectKey, current: &[u8]) -> bool {
121        let Some(snapshot) = self.snapshots.get(key) else {
122            return true;
123        };
124        current != snapshot.data
125    }
126
127    /// Get changed fields between snapshot and current state.
128    ///
129    /// Returns a list of field names that have different values.
130    #[tracing::instrument(level = "debug", skip(self, obj))]
131    pub fn changed_fields<T: Model + Serialize>(
132        &self,
133        key: &ObjectKey,
134        obj: &T,
135    ) -> Vec<&'static str> {
136        let Some(snapshot) = self.snapshots.get(key) else {
137            // No snapshot = all fields are "changed"
138            let fields: Vec<&'static str> = T::fields().iter().map(|f| f.name).collect();
139            tracing::debug!(
140                model = std::any::type_name::<T>(),
141                changed_count = fields.len(),
142                "No snapshot - all fields considered changed"
143            );
144            return fields;
145        };
146
147        // Parse both as JSON objects and compare fields
148        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
149            Ok(v) => v,
150            Err(e) => {
151                tracing::warn!(
152                    model = std::any::type_name::<T>(),
153                    error = %e,
154                    "Snapshot deserialization failed in changed_fields, treating all as changed"
155                );
156                serde_json::Value::Null
157            }
158        };
159        let current: serde_json::Value = match serde_json::to_value(obj) {
160            Ok(v) => v,
161            Err(e) => {
162                tracing::warn!(
163                    model = std::any::type_name::<T>(),
164                    error = %e,
165                    "Current serialization failed in changed_fields, treating all as changed"
166                );
167                serde_json::Value::Null
168            }
169        };
170
171        let mut changed = Vec::new();
172        for field in T::fields() {
173            let orig_val = original.get(field.name);
174            let curr_val = current.get(field.name);
175            if orig_val != curr_val {
176                changed.push(field.name);
177            }
178        }
179
180        tracing::debug!(
181            model = std::any::type_name::<T>(),
182            changed_count = changed.len(),
183            fields = ?changed,
184            "Detected changed fields"
185        );
186        changed
187    }
188
189    /// Get changed fields from raw JSON bytes.
190    pub fn changed_fields_raw(
191        &self,
192        key: &ObjectKey,
193        current_bytes: &[u8],
194        field_names: &[&'static str],
195    ) -> Vec<&'static str> {
196        let Some(snapshot) = self.snapshots.get(key) else {
197            return field_names.to_vec();
198        };
199
200        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
201            Ok(v) => v,
202            Err(e) => {
203                tracing::warn!(
204                    error = %e,
205                    "Snapshot deserialization failed in changed_fields_raw, treating all as changed"
206                );
207                serde_json::Value::Null
208            }
209        };
210        let current: serde_json::Value = match serde_json::from_slice(current_bytes) {
211            Ok(v) => v,
212            Err(e) => {
213                tracing::warn!(
214                    error = %e,
215                    "Current deserialization failed in changed_fields_raw, treating all as changed"
216                );
217                serde_json::Value::Null
218            }
219        };
220
221        let mut changed = Vec::new();
222        for name in field_names {
223            let orig_val = original.get(*name);
224            let curr_val = current.get(*name);
225            if orig_val != curr_val {
226                changed.push(*name);
227            }
228        }
229        changed
230    }
231
232    /// Get detailed attribute changes between snapshot and current state.
233    ///
234    /// Returns `AttributeChange` structs with field name, old value, and new value.
235    pub fn attribute_changes<T: Model + Serialize>(
236        &self,
237        key: &ObjectKey,
238        obj: &T,
239    ) -> Vec<sqlmodel_core::AttributeChange> {
240        let Some(snapshot) = self.snapshots.get(key) else {
241            return Vec::new();
242        };
243
244        let original: serde_json::Value = match serde_json::from_slice(&snapshot.data) {
245            Ok(v) => v,
246            Err(e) => {
247                tracing::warn!(
248                    model = std::any::type_name::<T>(),
249                    error = %e,
250                    "Snapshot deserialization failed in attribute_changes, treating as empty"
251                );
252                serde_json::Value::Null
253            }
254        };
255        let current: serde_json::Value = match serde_json::to_value(obj) {
256            Ok(v) => v,
257            Err(e) => {
258                tracing::warn!(
259                    model = std::any::type_name::<T>(),
260                    error = %e,
261                    "Current serialization failed in attribute_changes, treating as empty"
262                );
263                serde_json::Value::Null
264            }
265        };
266
267        let mut changes = Vec::new();
268        for field in T::fields() {
269            let orig_val = original
270                .get(field.name)
271                .cloned()
272                .unwrap_or(serde_json::Value::Null);
273            let curr_val = current
274                .get(field.name)
275                .cloned()
276                .unwrap_or(serde_json::Value::Null);
277            if orig_val != curr_val {
278                changes.push(sqlmodel_core::AttributeChange {
279                    field_name: field.name,
280                    old_value: orig_val,
281                    new_value: curr_val,
282                });
283            }
284        }
285        changes
286    }
287
288    /// Check if a snapshot exists for the given key.
289    pub fn has_snapshot(&self, key: &ObjectKey) -> bool {
290        self.snapshots.contains_key(key)
291    }
292
293    /// Get the snapshot for a key.
294    pub fn get_snapshot(&self, key: &ObjectKey) -> Option<&ObjectSnapshot> {
295        self.snapshots.get(key)
296    }
297
298    /// Clear snapshot for a specific object.
299    ///
300    /// Call this after commit or when discarding changes.
301    pub fn clear(&mut self, key: &ObjectKey) {
302        self.snapshots.remove(key);
303    }
304
305    /// Clear all snapshots.
306    ///
307    /// Call this after commit or rollback to reset tracking state.
308    pub fn clear_all(&mut self) {
309        self.snapshots.clear();
310    }
311
312    /// Update snapshot after flush (new baseline).
313    ///
314    /// Call this after a successful flush to set the current state as the new baseline.
315    #[tracing::instrument(level = "trace", skip(self, obj))]
316    pub fn refresh<T: Model + Serialize>(&mut self, key: ObjectKey, obj: &T) {
317        tracing::trace!(pk_hash = key.pk_hash(), "Refreshing snapshot");
318        self.snapshot(key, obj);
319    }
320
321    /// Number of tracked snapshots.
322    pub fn len(&self) -> usize {
323        self.snapshots.len()
324    }
325
326    /// Check if there are no snapshots.
327    pub fn is_empty(&self) -> bool {
328        self.snapshots.is_empty()
329    }
330}
331
332impl Default for ChangeTracker {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
338#[cfg(test)]
339mod tests {
340    use super::*;
341    use serde::{Deserialize, Serialize};
342    use sqlmodel_core::{FieldInfo, Row, Value};
343
344    // Mock model for testing
345    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
346    struct TestHero {
347        id: i64,
348        name: String,
349        age: Option<i32>,
350    }
351
352    impl Model for TestHero {
353        const TABLE_NAME: &'static str = "hero";
354        const PRIMARY_KEY: &'static [&'static str] = &["id"];
355
356        fn fields() -> &'static [FieldInfo] {
357            static FIELDS: [FieldInfo; 3] = [
358                FieldInfo::new("id", "id", sqlmodel_core::SqlType::BigInt)
359                    .primary_key(true)
360                    .auto_increment(true),
361                FieldInfo::new("name", "name", sqlmodel_core::SqlType::Text),
362                FieldInfo::new("age", "age", sqlmodel_core::SqlType::Integer).nullable(true),
363            ];
364            &FIELDS
365        }
366
367        fn primary_key_value(&self) -> Vec<Value> {
368            vec![Value::BigInt(self.id)]
369        }
370
371        fn from_row(row: &Row) -> Result<Self, sqlmodel_core::Error> {
372            Ok(Self {
373                id: row.get_named("id")?,
374                name: row.get_named("name")?,
375                age: row.get_named("age")?,
376            })
377        }
378
379        fn to_row(&self) -> Vec<(&'static str, Value)> {
380            vec![
381                ("id", Value::BigInt(self.id)),
382                ("name", Value::Text(self.name.clone())),
383                ("age", self.age.map_or(Value::Null, Value::Int)),
384            ]
385        }
386
387        fn is_new(&self) -> bool {
388            false
389        }
390    }
391
392    fn make_key(id: i64) -> ObjectKey {
393        ObjectKey::from_pk::<TestHero>(&[Value::BigInt(id)])
394    }
395
396    #[test]
397    fn test_snapshot_captures_current_state() {
398        let mut tracker = ChangeTracker::new();
399        let hero = TestHero {
400            id: 1,
401            name: "Spider-Man".to_string(),
402            age: Some(25),
403        };
404        let key = make_key(1);
405
406        tracker.snapshot(key, &hero);
407
408        assert!(tracker.has_snapshot(&key));
409        let snapshot = tracker.get_snapshot(&key).unwrap();
410        assert!(!snapshot.data().is_empty());
411    }
412
413    #[test]
414    fn test_snapshot_overwrites_previous() {
415        let mut tracker = ChangeTracker::new();
416        let key = make_key(1);
417
418        let hero1 = TestHero {
419            id: 1,
420            name: "Spider-Man".to_string(),
421            age: Some(25),
422        };
423        tracker.snapshot(key, &hero1);
424        let first_data = tracker.get_snapshot(&key).unwrap().data().to_vec();
425
426        let hero2 = TestHero {
427            id: 1,
428            name: "Peter Parker".to_string(),
429            age: Some(26),
430        };
431        tracker.snapshot(key, &hero2);
432        let second_data = tracker.get_snapshot(&key).unwrap().data().to_vec();
433
434        assert_ne!(first_data, second_data);
435    }
436
437    #[test]
438    fn test_is_dirty_false_if_unchanged() {
439        let mut tracker = ChangeTracker::new();
440        let hero = TestHero {
441            id: 1,
442            name: "Spider-Man".to_string(),
443            age: Some(25),
444        };
445        let key = make_key(1);
446
447        tracker.snapshot(key, &hero);
448
449        // Same object = not dirty
450        assert!(!tracker.is_dirty(&key, &hero));
451    }
452
453    #[test]
454    fn test_is_dirty_true_if_field_changed() {
455        let mut tracker = ChangeTracker::new();
456        let hero = TestHero {
457            id: 1,
458            name: "Spider-Man".to_string(),
459            age: Some(25),
460        };
461        let key = make_key(1);
462
463        tracker.snapshot(key, &hero);
464
465        // Modify the hero
466        let modified_hero = TestHero {
467            id: 1,
468            name: "Peter Parker".to_string(),
469            age: Some(25),
470        };
471
472        assert!(tracker.is_dirty(&key, &modified_hero));
473    }
474
475    #[test]
476    fn test_is_dirty_true_if_no_snapshot() {
477        let tracker = ChangeTracker::new();
478        let hero = TestHero {
479            id: 1,
480            name: "Spider-Man".to_string(),
481            age: Some(25),
482        };
483        let key = make_key(1);
484
485        // No snapshot = dirty
486        assert!(tracker.is_dirty(&key, &hero));
487    }
488
489    #[test]
490    fn test_changed_fields_empty_if_unchanged() {
491        let mut tracker = ChangeTracker::new();
492        let hero = TestHero {
493            id: 1,
494            name: "Spider-Man".to_string(),
495            age: Some(25),
496        };
497        let key = make_key(1);
498
499        tracker.snapshot(key, &hero);
500
501        let changed = tracker.changed_fields(&key, &hero);
502        assert!(changed.is_empty());
503    }
504
505    #[test]
506    fn test_changed_fields_lists_modified() {
507        let mut tracker = ChangeTracker::new();
508        let hero = TestHero {
509            id: 1,
510            name: "Spider-Man".to_string(),
511            age: Some(25),
512        };
513        let key = make_key(1);
514
515        tracker.snapshot(key, &hero);
516
517        let modified_hero = TestHero {
518            id: 1,
519            name: "Peter Parker".to_string(),
520            age: Some(25),
521        };
522
523        let changed = tracker.changed_fields(&key, &modified_hero);
524        assert_eq!(changed, vec!["name"]);
525    }
526
527    #[test]
528    fn test_changed_fields_multiple_changes() {
529        let mut tracker = ChangeTracker::new();
530        let hero = TestHero {
531            id: 1,
532            name: "Spider-Man".to_string(),
533            age: Some(25),
534        };
535        let key = make_key(1);
536
537        tracker.snapshot(key, &hero);
538
539        let modified_hero = TestHero {
540            id: 1,
541            name: "Peter Parker".to_string(),
542            age: Some(30),
543        };
544
545        let changed = tracker.changed_fields(&key, &modified_hero);
546        assert!(changed.contains(&"name"));
547        assert!(changed.contains(&"age"));
548        assert!(!changed.contains(&"id"));
549    }
550
551    #[test]
552    fn test_clear_removes_snapshot() {
553        let mut tracker = ChangeTracker::new();
554        let hero = TestHero {
555            id: 1,
556            name: "Spider-Man".to_string(),
557            age: Some(25),
558        };
559        let key = make_key(1);
560
561        tracker.snapshot(key, &hero);
562        assert!(tracker.has_snapshot(&key));
563
564        tracker.clear(&key);
565        assert!(!tracker.has_snapshot(&key));
566    }
567
568    #[test]
569    fn test_clear_all_removes_all() {
570        let mut tracker = ChangeTracker::new();
571
572        let hero1 = TestHero {
573            id: 1,
574            name: "Spider-Man".to_string(),
575            age: Some(25),
576        };
577        let hero2 = TestHero {
578            id: 2,
579            name: "Iron Man".to_string(),
580            age: Some(40),
581        };
582
583        tracker.snapshot(make_key(1), &hero1);
584        tracker.snapshot(make_key(2), &hero2);
585
586        assert_eq!(tracker.len(), 2);
587
588        tracker.clear_all();
589
590        assert!(tracker.is_empty());
591    }
592
593    #[test]
594    fn test_refresh_updates_baseline() {
595        let mut tracker = ChangeTracker::new();
596        let hero = TestHero {
597            id: 1,
598            name: "Spider-Man".to_string(),
599            age: Some(25),
600        };
601        let key = make_key(1);
602
603        tracker.snapshot(key, &hero);
604
605        let modified_hero = TestHero {
606            id: 1,
607            name: "Peter Parker".to_string(),
608            age: Some(25),
609        };
610
611        // Should be dirty before refresh
612        assert!(tracker.is_dirty(&key, &modified_hero));
613
614        // Refresh the baseline
615        tracker.refresh(key, &modified_hero);
616
617        // No longer dirty
618        assert!(!tracker.is_dirty(&key, &modified_hero));
619    }
620
621    #[test]
622    fn test_attribute_changes_empty_when_unchanged() {
623        let mut tracker = ChangeTracker::new();
624        let hero = TestHero {
625            id: 1,
626            name: "Spider-Man".to_string(),
627            age: Some(25),
628        };
629        let key = ObjectKey::from_model(&hero);
630        tracker.snapshot(key, &hero);
631
632        let changes = tracker.attribute_changes(&key, &hero);
633        assert!(changes.is_empty());
634    }
635
636    #[test]
637    fn test_attribute_changes_detects_field_change() {
638        let mut tracker = ChangeTracker::new();
639        let hero = TestHero {
640            id: 1,
641            name: "Spider-Man".to_string(),
642            age: Some(25),
643        };
644        let key = ObjectKey::from_model(&hero);
645        tracker.snapshot(key, &hero);
646
647        let modified = TestHero {
648            id: 1,
649            name: "Peter Parker".to_string(),
650            age: Some(26),
651        };
652
653        let changes = tracker.attribute_changes(&key, &modified);
654        assert_eq!(changes.len(), 2);
655        assert_eq!(changes[0].field_name, "name");
656        assert_eq!(changes[0].old_value, serde_json::json!("Spider-Man"));
657        assert_eq!(changes[0].new_value, serde_json::json!("Peter Parker"));
658        assert_eq!(changes[1].field_name, "age");
659        assert_eq!(changes[1].old_value, serde_json::json!(25));
660        assert_eq!(changes[1].new_value, serde_json::json!(26));
661    }
662
663    #[test]
664    fn test_attribute_changes_empty_without_snapshot() {
665        let tracker = ChangeTracker::new();
666        let hero = TestHero {
667            id: 1,
668            name: "Spider-Man".to_string(),
669            age: Some(25),
670        };
671        let key = ObjectKey::from_model(&hero);
672
673        // No snapshot → empty changes (not all fields)
674        let changes = tracker.attribute_changes(&key, &hero);
675        assert!(changes.is_empty());
676    }
677}