Skip to main content

sqlmodel_session/
lib.rs

//! Session and Unit of Work for SQLModel Rust.
//!
//! `sqlmodel-session` is the **unit-of-work layer**. It coordinates object identity,
//! change tracking, and transactional persistence in a way that mirrors Python SQLModel
//! while staying explicit and Rust-idiomatic.
//!
//! # Role In The Architecture
//!
//! - **Identity map**: ensures a single in-memory instance per primary key.
//! - **Change tracking**: records inserts, updates, and deletes before flush.
//! - **Transactional safety**: wraps flush/commit/rollback around a `Connection`.
//!
//! # Design Philosophy
//!
//! - **Explicit over implicit**: No autoflush by default.
//! - **Ownership clarity**: Session owns its connection (which may be a pooled connection).
//! - **Type erasure**: Identity map stores `Box<dyn Any + Send + Sync>` for heterogeneous models.
//! - **Cancel-correct**: All async operations take a `Cx` and return an `Outcome`
//!   (both from `asupersync`).
//!
//! # Example
//!
//! ```ignore
//! // Create a session that owns a connection (e.g. one checked out of a pool)
//! let mut session = Session::new(conn);
//!
//! // Add new objects (will be INSERTed on flush)
//! session.add(&hero);
//!
//! // Get by primary key (uses identity map); async operations take a `&Cx`
//! let hero = session.get::<Hero>(cx, 1).await?.unwrap();
//!
//! // Mark for deletion
//! session.delete(&hero);
//!
//! // Flush pending changes to DB
//! session.flush(cx).await?;
//!
//! // Commit the transaction
//! session.commit(cx).await?;
//! ```
41
42pub mod change_tracker;
43pub mod flush;
44pub mod identity_map;
45pub mod n1_detection;
46pub mod unit_of_work;
47
48pub use change_tracker::{ChangeTracker, ObjectSnapshot};
49pub use flush::{
50    FlushOrderer, FlushPlan, FlushResult, LinkTableOp, PendingOp, execute_link_table_ops,
51};
52pub use identity_map::{IdentityMap, ModelReadGuard, ModelRef, ModelWriteGuard, WeakIdentityMap};
53pub use n1_detection::{CallSite, N1DetectionScope, N1QueryTracker, N1Stats};
54pub use unit_of_work::{PendingCounts, UnitOfWork, UowError};
55
56use asupersync::{Cx, Outcome};
57use serde::{Deserialize, Serialize};
58use sqlmodel_core::{Connection, Error, Lazy, LazyLoader, Model, Value};
59use std::any::{Any, TypeId};
60use std::collections::HashMap;
61use std::future::Future;
62use std::hash::{Hash, Hasher};
63
64// ============================================================================
65// Session Events
66// ============================================================================
67
68/// Type alias for session event callbacks.
69///
70/// Callbacks receive no arguments and return `Result<(), Error>`.
71/// Returning `Err` will abort the operation (e.g., prevent commit).
72type SessionEventFn = Box<dyn FnMut() -> Result<(), Error> + Send>;
73
74/// Holds registered session-level event callbacks.
75///
76/// These are fired at key points in the session lifecycle:
77/// before/after flush, commit, and rollback.
78#[derive(Default)]
79pub struct SessionEventCallbacks {
80    before_flush: Vec<SessionEventFn>,
81    after_flush: Vec<SessionEventFn>,
82    before_commit: Vec<SessionEventFn>,
83    after_commit: Vec<SessionEventFn>,
84    after_rollback: Vec<SessionEventFn>,
85}
86
87impl std::fmt::Debug for SessionEventCallbacks {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("SessionEventCallbacks")
90            .field("before_flush", &self.before_flush.len())
91            .field("after_flush", &self.after_flush.len())
92            .field("before_commit", &self.before_commit.len())
93            .field("after_commit", &self.after_commit.len())
94            .field("after_rollback", &self.after_rollback.len())
95            .finish()
96    }
97}
98
99impl SessionEventCallbacks {
100    #[allow(clippy::result_large_err)]
101    fn fire(&mut self, event: SessionEvent) -> Result<(), Error> {
102        let callbacks = match event {
103            SessionEvent::BeforeFlush => &mut self.before_flush,
104            SessionEvent::AfterFlush => &mut self.after_flush,
105            SessionEvent::BeforeCommit => &mut self.before_commit,
106            SessionEvent::AfterCommit => &mut self.after_commit,
107            SessionEvent::AfterRollback => &mut self.after_rollback,
108        };
109        for cb in callbacks.iter_mut() {
110            cb()?;
111        }
112        Ok(())
113    }
114}
115
/// Points in the session lifecycle at which registered callbacks are run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionEvent {
    /// Before pending changes are flushed.
    BeforeFlush,
    /// After a flush has completed successfully.
    AfterFlush,
    /// Before a commit, once the flush has run.
    BeforeCommit,
    /// After a commit has completed successfully.
    AfterCommit,
    /// After a rollback has completed.
    AfterRollback,
}
130
131// ============================================================================
132// Session Configuration
133// ============================================================================
134
/// Tunable behavior for a `Session`.
///
/// Defaults: `auto_begin = true`, `auto_flush = false`, `expire_on_commit = true`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Begin a transaction automatically on the first operation.
    pub auto_begin: bool,
    /// Flush pending changes before running queries (off by default; it costs performance).
    pub auto_flush: bool,
    /// After commit, mark objects expired so the next access reloads them from the database.
    pub expire_on_commit: bool,
}

impl Default for SessionConfig {
    fn default() -> Self {
        SessionConfig {
            expire_on_commit: true,
            auto_flush: false,
            auto_begin: true,
        }
    }
}
155
/// Options for `Session::get_with_options()`.
///
/// Build with [`GetOptions::new`] and the chainable setters. `skip_locked` and
/// `nowait` only take effect when `with_for_update` is also set; if both are
/// set, `skip_locked` takes precedence.
#[derive(Debug, Clone, Default)]
pub struct GetOptions {
    /// Lock the selected row with `SELECT ... FOR UPDATE`.
    pub with_for_update: bool,
    /// Append `SKIP LOCKED` to the `FOR UPDATE` clause.
    pub skip_locked: bool,
    /// Append `NOWAIT` to the `FOR UPDATE` clause.
    pub nowait: bool,
}

impl GetOptions {
    /// All options off (identical to `GetOptions::default()`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder-style setter for `with_for_update`.
    #[must_use]
    pub fn with_for_update(self, value: bool) -> Self {
        Self {
            with_for_update: value,
            ..self
        }
    }

    /// Builder-style setter for `skip_locked`.
    #[must_use]
    pub fn skip_locked(self, value: bool) -> Self {
        Self {
            skip_locked: value,
            ..self
        }
    }

    /// Builder-style setter for `nowait`.
    #[must_use]
    pub fn nowait(self, value: bool) -> Self {
        Self {
            nowait: value,
            ..self
        }
    }
}
195
196// ============================================================================
197// Object Key and State
198// ============================================================================
199
200/// Unique key for an object in the identity map.
201#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
202pub struct ObjectKey {
203    /// Type identifier for the Model type.
204    type_id: TypeId,
205    /// Hash of the primary key value(s).
206    pk_hash: u64,
207}
208
209impl ObjectKey {
210    /// Create an object key from a model instance.
211    pub fn from_model<M: Model + 'static>(obj: &M) -> Self {
212        let pk_values = obj.primary_key_value();
213        Self {
214            type_id: TypeId::of::<M>(),
215            pk_hash: hash_values(&pk_values),
216        }
217    }
218
219    /// Create an object key from type and primary key.
220    pub fn from_pk<M: Model + 'static>(pk: &[Value]) -> Self {
221        Self {
222            type_id: TypeId::of::<M>(),
223            pk_hash: hash_values(pk),
224        }
225    }
226
227    /// Get the primary key hash.
228    pub fn pk_hash(&self) -> u64 {
229        self.pk_hash
230    }
231
232    /// Get the type identifier.
233    pub fn type_id(&self) -> TypeId {
234        self.type_id
235    }
236}
237
238/// Hash a slice of values for use as a primary key hash.
239fn hash_values(values: &[Value]) -> u64 {
240    use std::collections::hash_map::DefaultHasher;
241    let mut hasher = DefaultHasher::new();
242    for v in values {
243        // Hash based on value variant and content
244        match v {
245            Value::Null => 0u8.hash(&mut hasher),
246            Value::Bool(b) => {
247                1u8.hash(&mut hasher);
248                b.hash(&mut hasher);
249            }
250            Value::TinyInt(i) => {
251                2u8.hash(&mut hasher);
252                i.hash(&mut hasher);
253            }
254            Value::SmallInt(i) => {
255                3u8.hash(&mut hasher);
256                i.hash(&mut hasher);
257            }
258            Value::Int(i) => {
259                4u8.hash(&mut hasher);
260                i.hash(&mut hasher);
261            }
262            Value::BigInt(i) => {
263                5u8.hash(&mut hasher);
264                i.hash(&mut hasher);
265            }
266            Value::Float(f) => {
267                6u8.hash(&mut hasher);
268                f.to_bits().hash(&mut hasher);
269            }
270            Value::Double(f) => {
271                7u8.hash(&mut hasher);
272                f.to_bits().hash(&mut hasher);
273            }
274            Value::Decimal(s) => {
275                8u8.hash(&mut hasher);
276                s.hash(&mut hasher);
277            }
278            Value::Text(s) => {
279                9u8.hash(&mut hasher);
280                s.hash(&mut hasher);
281            }
282            Value::Bytes(b) => {
283                10u8.hash(&mut hasher);
284                b.hash(&mut hasher);
285            }
286            Value::Date(d) => {
287                11u8.hash(&mut hasher);
288                d.hash(&mut hasher);
289            }
290            Value::Time(t) => {
291                12u8.hash(&mut hasher);
292                t.hash(&mut hasher);
293            }
294            Value::Timestamp(ts) => {
295                13u8.hash(&mut hasher);
296                ts.hash(&mut hasher);
297            }
298            Value::TimestampTz(ts) => {
299                14u8.hash(&mut hasher);
300                ts.hash(&mut hasher);
301            }
302            Value::Uuid(u) => {
303                15u8.hash(&mut hasher);
304                u.hash(&mut hasher);
305            }
306            Value::Json(j) => {
307                16u8.hash(&mut hasher);
308                // Hash the JSON string representation
309                j.to_string().hash(&mut hasher);
310            }
311            Value::Array(arr) => {
312                17u8.hash(&mut hasher);
313                // Recursively hash array elements
314                arr.len().hash(&mut hasher);
315                for item in arr {
316                    hash_value(item, &mut hasher);
317                }
318            }
319            Value::Default => {
320                18u8.hash(&mut hasher);
321            }
322        }
323    }
324    hasher.finish()
325}
326
327/// Hash a single value into the hasher.
328fn hash_value(v: &Value, hasher: &mut impl Hasher) {
329    match v {
330        Value::Null => 0u8.hash(hasher),
331        Value::Bool(b) => {
332            1u8.hash(hasher);
333            b.hash(hasher);
334        }
335        Value::TinyInt(i) => {
336            2u8.hash(hasher);
337            i.hash(hasher);
338        }
339        Value::SmallInt(i) => {
340            3u8.hash(hasher);
341            i.hash(hasher);
342        }
343        Value::Int(i) => {
344            4u8.hash(hasher);
345            i.hash(hasher);
346        }
347        Value::BigInt(i) => {
348            5u8.hash(hasher);
349            i.hash(hasher);
350        }
351        Value::Float(f) => {
352            6u8.hash(hasher);
353            f.to_bits().hash(hasher);
354        }
355        Value::Double(f) => {
356            7u8.hash(hasher);
357            f.to_bits().hash(hasher);
358        }
359        Value::Decimal(s) => {
360            8u8.hash(hasher);
361            s.hash(hasher);
362        }
363        Value::Text(s) => {
364            9u8.hash(hasher);
365            s.hash(hasher);
366        }
367        Value::Bytes(b) => {
368            10u8.hash(hasher);
369            b.hash(hasher);
370        }
371        Value::Date(d) => {
372            11u8.hash(hasher);
373            d.hash(hasher);
374        }
375        Value::Time(t) => {
376            12u8.hash(hasher);
377            t.hash(hasher);
378        }
379        Value::Timestamp(ts) => {
380            13u8.hash(hasher);
381            ts.hash(hasher);
382        }
383        Value::TimestampTz(ts) => {
384            14u8.hash(hasher);
385            ts.hash(hasher);
386        }
387        Value::Uuid(u) => {
388            15u8.hash(hasher);
389            u.hash(hasher);
390        }
391        Value::Json(j) => {
392            16u8.hash(hasher);
393            j.to_string().hash(hasher);
394        }
395        Value::Array(arr) => {
396            17u8.hash(hasher);
397            arr.len().hash(hasher);
398            for item in arr {
399                hash_value(item, hasher);
400            }
401        }
402        Value::Default => {
403            18u8.hash(hasher);
404        }
405    }
406}
407
/// Lifecycle state of an object tracked by a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectState {
    /// Added to the session but not yet in the database; INSERTed on flush.
    New,
    /// Loaded from the database (or otherwise known to exist there).
    Persistent,
    /// Scheduled for DELETE on the next flush.
    Deleted,
    /// No longer managed by the session.
    Detached,
    /// Stale; the next lookup reloads it from the database.
    Expired,
}
422
423/// A tracked object in the session.
424struct TrackedObject {
425    /// The actual object (type-erased).
426    object: Box<dyn Any + Send + Sync>,
427    /// Original serialized state for dirty checking.
428    original_state: Option<Vec<u8>>,
429    /// Current object state.
430    state: ObjectState,
431    /// Table name for this object.
432    table_name: &'static str,
433    /// Column names for this object.
434    column_names: Vec<&'static str>,
435    /// Current values for each column (for INSERT/UPDATE).
436    values: Vec<Value>,
437    /// Primary key column names.
438    pk_columns: Vec<&'static str>,
439    /// Primary key values (for DELETE/UPDATE WHERE clause).
440    pk_values: Vec<Value>,
441    /// Set of expired attribute names (None = all expired, Some(empty) = none expired).
442    /// When Some(non-empty), only those specific attributes need reload.
443    expired_attributes: Option<std::collections::HashSet<String>>,
444}
445
446// ============================================================================
447// Session
448// ============================================================================
449
450/// The Session is the central unit-of-work manager.
451///
452/// It tracks objects loaded from or added to the database and coordinates
453/// flushing changes back to the database.
454pub struct Session<C: Connection> {
455    /// The database connection.
456    connection: C,
457    /// Whether we're in a transaction.
458    in_transaction: bool,
459    /// Identity map: ObjectKey -> TrackedObject.
460    identity_map: HashMap<ObjectKey, TrackedObject>,
461    /// Objects marked as new (need INSERT).
462    pending_new: Vec<ObjectKey>,
463    /// Objects marked as deleted (need DELETE).
464    pending_delete: Vec<ObjectKey>,
465    /// Objects that are dirty (need UPDATE).
466    pending_dirty: Vec<ObjectKey>,
467    /// Configuration.
468    config: SessionConfig,
469    /// N+1 query detection tracker (optional).
470    n1_tracker: Option<N1QueryTracker>,
471    /// Session-level event callbacks.
472    event_callbacks: SessionEventCallbacks,
473}
474
475impl<C: Connection> Session<C> {
476    /// Create a new session from an existing connection.
477    pub fn new(connection: C) -> Self {
478        Self::with_config(connection, SessionConfig::default())
479    }
480
481    /// Create a new session with custom configuration.
482    pub fn with_config(connection: C, config: SessionConfig) -> Self {
483        Self {
484            connection,
485            in_transaction: false,
486            identity_map: HashMap::new(),
487            pending_new: Vec::new(),
488            pending_delete: Vec::new(),
489            pending_dirty: Vec::new(),
490            config,
491            n1_tracker: None,
492            event_callbacks: SessionEventCallbacks::default(),
493        }
494    }
495
496    /// Get a reference to the underlying connection.
497    pub fn connection(&self) -> &C {
498        &self.connection
499    }
500
501    /// Get the session configuration.
502    pub fn config(&self) -> &SessionConfig {
503        &self.config
504    }
505
506    // ========================================================================
507    // Session Events
508    // ========================================================================
509
510    /// Register a callback to run before flush.
511    ///
512    /// The callback can abort the flush by returning `Err`.
513    pub fn on_before_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
514        self.event_callbacks.before_flush.push(Box::new(f));
515    }
516
517    /// Register a callback to run after a successful flush.
518    pub fn on_after_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
519        self.event_callbacks.after_flush.push(Box::new(f));
520    }
521
522    /// Register a callback to run before commit (after flush).
523    ///
524    /// The callback can abort the commit by returning `Err`.
525    pub fn on_before_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
526        self.event_callbacks.before_commit.push(Box::new(f));
527    }
528
529    /// Register a callback to run after a successful commit.
530    pub fn on_after_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
531        self.event_callbacks.after_commit.push(Box::new(f));
532    }
533
534    /// Register a callback to run after rollback.
535    pub fn on_after_rollback(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
536        self.event_callbacks.after_rollback.push(Box::new(f));
537    }
538
539    // ========================================================================
540    // Object Tracking
541    // ========================================================================
542
543    /// Add a new object to the session.
544    ///
545    /// The object will be INSERTed on the next `flush()` call.
546    pub fn add<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
547        let key = ObjectKey::from_model(obj);
548
549        // If already tracked, update the object and its values
550        if let Some(tracked) = self.identity_map.get_mut(&key) {
551            tracked.object = Box::new(obj.clone());
552
553            // Update stored values to match the new object state
554            let row_data = obj.to_row();
555            tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
556            tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
557            tracked.pk_values = obj.primary_key_value();
558
559            if tracked.state == ObjectState::Deleted {
560                // Un-delete: remove from pending_delete and restore state
561                self.pending_delete.retain(|k| k != &key);
562
563                if tracked.original_state.is_some() {
564                    // Was previously persisted - restore to Persistent (will need UPDATE if changed)
565                    tracked.state = ObjectState::Persistent;
566                } else {
567                    // Was never persisted - restore to New and schedule for INSERT
568                    tracked.state = ObjectState::New;
569                    if !self.pending_new.contains(&key) {
570                        self.pending_new.push(key);
571                    }
572                }
573            }
574            return;
575        }
576
577        // Extract column data from the model while we have the concrete type
578        let row_data = obj.to_row();
579        let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
580        let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
581
582        // Extract primary key info
583        let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
584        let pk_values = obj.primary_key_value();
585
586        let tracked = TrackedObject {
587            object: Box::new(obj.clone()),
588            original_state: None, // New objects have no original state
589            state: ObjectState::New,
590            table_name: M::TABLE_NAME,
591            column_names,
592            values,
593            pk_columns,
594            pk_values,
595            expired_attributes: None,
596        };
597
598        self.identity_map.insert(key, tracked);
599        self.pending_new.push(key);
600    }
601
602    /// Add multiple objects to the session at once.
603    ///
604    /// This is equivalent to calling `add()` for each object, but provides a more
605    /// convenient API for bulk operations.
606    ///
607    /// # Example
608    ///
609    /// ```ignore
610    /// let users = vec![user1, user2, user3];
611    /// session.add_all(&users);
612    ///
613    /// // Or with an iterator
614    /// session.add_all(users.iter());
615    /// ```
616    ///
617    /// All objects will be INSERTed on the next `flush()` call.
618    pub fn add_all<'a, M, I>(&mut self, objects: I)
619    where
620        M: Model + Clone + Send + Sync + Serialize + 'static,
621        I: IntoIterator<Item = &'a M>,
622    {
623        for obj in objects {
624            self.add(obj);
625        }
626    }
627
628    /// Delete an object from the session.
629    ///
630    /// The object will be DELETEd on the next `flush()` call.
631    pub fn delete<M: Model + 'static>(&mut self, obj: &M) {
632        let key = ObjectKey::from_model(obj);
633
634        if let Some(tracked) = self.identity_map.get_mut(&key) {
635            match tracked.state {
636                ObjectState::New => {
637                    // If it's new, just remove it entirely
638                    self.identity_map.remove(&key);
639                    self.pending_new.retain(|k| k != &key);
640                }
641                ObjectState::Persistent | ObjectState::Expired => {
642                    tracked.state = ObjectState::Deleted;
643                    self.pending_delete.push(key);
644                    self.pending_dirty.retain(|k| k != &key);
645                }
646                ObjectState::Deleted | ObjectState::Detached => {
647                    // Already deleted or detached, nothing to do
648                }
649            }
650        }
651    }
652
653    /// Mark an object as dirty (modified) so it will be UPDATEd on flush.
654    ///
655    /// This updates the stored values from the object and schedules an UPDATE.
656    /// Only works for objects that are already tracked as Persistent.
657    ///
658    /// # Example
659    ///
660    /// ```ignore
661    /// let mut hero = session.get::<Hero>(1).await?.unwrap();
662    /// hero.name = "New Name".to_string();
663    /// session.mark_dirty(&hero);  // Schedule for UPDATE
664    /// session.flush(cx).await?;   // Execute the UPDATE
665    /// ```
666    pub fn mark_dirty<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
667        let key = ObjectKey::from_model(obj);
668
669        if let Some(tracked) = self.identity_map.get_mut(&key) {
670            // Only mark persistent objects as dirty
671            if tracked.state != ObjectState::Persistent {
672                return;
673            }
674
675            // Update the stored object and values
676            tracked.object = Box::new(obj.clone());
677            let row_data = obj.to_row();
678            tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
679            tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
680            tracked.pk_values = obj.primary_key_value();
681
682            // Add to pending dirty if not already there
683            if !self.pending_dirty.contains(&key) {
684                self.pending_dirty.push(key);
685            }
686        }
687    }
688
689    /// Get an object by primary key.
690    ///
691    /// First checks the identity map, then queries the database if not found.
692    pub async fn get<
693        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
694    >(
695        &mut self,
696        cx: &Cx,
697        pk: impl Into<Value>,
698    ) -> Outcome<Option<M>, Error> {
699        let pk_value = pk.into();
700        let pk_values = vec![pk_value.clone()];
701        let key = ObjectKey::from_pk::<M>(&pk_values);
702
703        // Check identity map first (skip if expired - will reload below)
704        if let Some(tracked) = self.identity_map.get(&key) {
705            match tracked.state {
706                ObjectState::Deleted | ObjectState::Detached => {
707                    // Return None for deleted/detached objects
708                }
709                ObjectState::Expired => {
710                    // Skip cache, will reload from DB below
711                    tracing::debug!("Object is expired, reloading from database");
712                }
713                ObjectState::New | ObjectState::Persistent => {
714                    if let Some(obj) = tracked.object.downcast_ref::<M>() {
715                        return Outcome::Ok(Some(obj.clone()));
716                    }
717                }
718            }
719        }
720
721        // Query from database
722        let pk_col = M::PRIMARY_KEY.first().unwrap_or(&"id");
723        let sql = format!(
724            "SELECT * FROM \"{}\" WHERE \"{}\" = $1 LIMIT 1",
725            M::TABLE_NAME,
726            pk_col
727        );
728
729        let rows = match self.connection.query(cx, &sql, &[pk_value]).await {
730            Outcome::Ok(rows) => rows,
731            Outcome::Err(e) => return Outcome::Err(e),
732            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
733            Outcome::Panicked(p) => return Outcome::Panicked(p),
734        };
735
736        if rows.is_empty() {
737            return Outcome::Ok(None);
738        }
739
740        // Convert row to model
741        let obj = match M::from_row(&rows[0]) {
742            Ok(obj) => obj,
743            Err(e) => return Outcome::Err(e),
744        };
745
746        // Extract column data from the model while we have the concrete type
747        let row_data = obj.to_row();
748        let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
749        let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
750
751        // Serialize values for dirty checking (must match format used in flush)
752        let serialized = serde_json::to_vec(&values).ok();
753
754        // Extract primary key info
755        let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
756        let obj_pk_values = obj.primary_key_value();
757
758        let tracked = TrackedObject {
759            object: Box::new(obj.clone()),
760            original_state: serialized,
761            state: ObjectState::Persistent,
762            table_name: M::TABLE_NAME,
763            column_names,
764            values,
765            pk_columns,
766            pk_values: obj_pk_values,
767            expired_attributes: None,
768        };
769
770        self.identity_map.insert(key, tracked);
771
772        Outcome::Ok(Some(obj))
773    }
774
775    /// Get an object by composite primary key.
776    ///
777    /// First checks the identity map, then queries the database if not found.
778    ///
779    /// # Example
780    ///
781    /// ```ignore
782    /// // Composite PK lookup
783    /// let item = session.get_by_pk::<OrderItem>(&[
784    ///     Value::BigInt(order_id),
785    ///     Value::BigInt(product_id),
786    /// ]).await?;
787    /// ```
788    pub async fn get_by_pk<
789        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
790    >(
791        &mut self,
792        cx: &Cx,
793        pk_values: &[Value],
794    ) -> Outcome<Option<M>, Error> {
795        self.get_with_options::<M>(cx, pk_values, &GetOptions::default())
796            .await
797    }
798
799    /// Get an object by primary key with options.
800    ///
801    /// This is the most flexible form of `get()` supporting:
802    /// - Composite primary keys via `&[Value]`
803    /// - `with_for_update` for row locking
804    ///
805    /// # Example
806    ///
807    /// ```ignore
808    /// let options = GetOptions::default().with_for_update(true);
809    /// let user = session.get_with_options::<User>(&[Value::BigInt(1)], &options).await?;
810    /// ```
811    pub async fn get_with_options<
812        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
813    >(
814        &mut self,
815        cx: &Cx,
816        pk_values: &[Value],
817        options: &GetOptions,
818    ) -> Outcome<Option<M>, Error> {
819        let key = ObjectKey::from_pk::<M>(pk_values);
820
821        // Check identity map first (unless with_for_update which needs fresh DB state)
822        if !options.with_for_update {
823            if let Some(tracked) = self.identity_map.get(&key) {
824                match tracked.state {
825                    ObjectState::Deleted | ObjectState::Detached => {
826                        // Return None for deleted/detached objects
827                    }
828                    ObjectState::Expired => {
829                        // Skip cache, will reload from DB below
830                        tracing::debug!("Object is expired, reloading from database");
831                    }
832                    ObjectState::New | ObjectState::Persistent => {
833                        if let Some(obj) = tracked.object.downcast_ref::<M>() {
834                            return Outcome::Ok(Some(obj.clone()));
835                        }
836                    }
837                }
838            }
839        }
840
841        // Build WHERE clause for composite PK
842        let pk_columns = M::PRIMARY_KEY;
843        if pk_columns.len() != pk_values.len() {
844            return Outcome::Err(Error::Custom(format!(
845                "Primary key mismatch: expected {} values, got {}",
846                pk_columns.len(),
847                pk_values.len()
848            )));
849        }
850
851        let where_parts: Vec<String> = pk_columns
852            .iter()
853            .enumerate()
854            .map(|(i, col)| format!("\"{}\" = ${}", col, i + 1))
855            .collect();
856
857        let mut sql = format!(
858            "SELECT * FROM \"{}\" WHERE {} LIMIT 1",
859            M::TABLE_NAME,
860            where_parts.join(" AND ")
861        );
862
863        // Add FOR UPDATE if requested
864        if options.with_for_update {
865            sql.push_str(" FOR UPDATE");
866            if options.skip_locked {
867                sql.push_str(" SKIP LOCKED");
868            } else if options.nowait {
869                sql.push_str(" NOWAIT");
870            }
871        }
872
873        let rows = match self.connection.query(cx, &sql, pk_values).await {
874            Outcome::Ok(rows) => rows,
875            Outcome::Err(e) => return Outcome::Err(e),
876            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
877            Outcome::Panicked(p) => return Outcome::Panicked(p),
878        };
879
880        if rows.is_empty() {
881            return Outcome::Ok(None);
882        }
883
884        // Convert row to model
885        let obj = match M::from_row(&rows[0]) {
886            Ok(obj) => obj,
887            Err(e) => return Outcome::Err(e),
888        };
889
890        // Extract column data from the model while we have the concrete type
891        let row_data = obj.to_row();
892        let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
893        let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
894
895        // Serialize values for dirty checking
896        let serialized = serde_json::to_vec(&values).ok();
897
898        // Extract primary key info
899        let pk_cols: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
900        let obj_pk_values = obj.primary_key_value();
901
902        let tracked = TrackedObject {
903            object: Box::new(obj.clone()),
904            original_state: serialized,
905            state: ObjectState::Persistent,
906            table_name: M::TABLE_NAME,
907            column_names,
908            values,
909            pk_columns: pk_cols,
910            pk_values: obj_pk_values,
911            expired_attributes: None,
912        };
913
914        self.identity_map.insert(key, tracked);
915
916        Outcome::Ok(Some(obj))
917    }
918
919    /// Check if an object is tracked by this session.
920    pub fn contains<M: Model + 'static>(&self, obj: &M) -> bool {
921        let key = ObjectKey::from_model(obj);
922        self.identity_map.contains_key(&key)
923    }
924
925    /// Detach an object from the session.
926    pub fn expunge<M: Model + 'static>(&mut self, obj: &M) {
927        let key = ObjectKey::from_model(obj);
928        if let Some(tracked) = self.identity_map.get_mut(&key) {
929            tracked.state = ObjectState::Detached;
930        }
931        self.pending_new.retain(|k| k != &key);
932        self.pending_delete.retain(|k| k != &key);
933        self.pending_dirty.retain(|k| k != &key);
934    }
935
936    /// Detach all objects from the session.
937    pub fn expunge_all(&mut self) {
938        for tracked in self.identity_map.values_mut() {
939            tracked.state = ObjectState::Detached;
940        }
941        self.pending_new.clear();
942        self.pending_delete.clear();
943        self.pending_dirty.clear();
944    }
945
946    // ========================================================================
947    // Dirty Checking
948    // ========================================================================
949
950    /// Check if an object has pending changes.
951    ///
952    /// Returns `true` if:
953    /// - Object is new (pending INSERT)
954    /// - Object has been modified since load (pending UPDATE)
955    /// - Object is marked for deletion (pending DELETE)
956    ///
957    /// Returns `false` if:
958    /// - Object is not tracked
959    /// - Object is clean (unchanged since load)
960    /// - Object is detached or expired
961    ///
962    /// # Example
963    ///
964    /// ```ignore
965    /// let user = session.get::<User>(1).await?.unwrap();
966    /// assert!(!session.is_modified(&user));  // Fresh from DB
967    ///
968    /// // Modify and re-check
969    /// let mut user_mut = user.clone();
970    /// user_mut.name = "New Name".to_string();
971    /// session.mark_dirty(&user_mut);
972    /// assert!(session.is_modified(&user_mut));  // Now dirty
973    /// ```
974    pub fn is_modified<M: Model + Serialize + 'static>(&self, obj: &M) -> bool {
975        let key = ObjectKey::from_model(obj);
976
977        let Some(tracked) = self.identity_map.get(&key) else {
978            return false;
979        };
980
981        match tracked.state {
982            // New objects are always "modified" (pending INSERT)
983            ObjectState::New => true,
984
985            // Deleted objects are "modified" (pending DELETE)
986            ObjectState::Deleted => true,
987
988            // Detached/expired objects aren't modified in session context
989            ObjectState::Detached | ObjectState::Expired => false,
990
991            // For persistent objects, compare current values to original
992            ObjectState::Persistent => {
993                // Check if explicitly marked dirty
994                if self.pending_dirty.contains(&key) {
995                    return true;
996                }
997
998                // Compare serialized values
999                let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
1000                tracked.original_state.as_ref() != Some(&current_state)
1001            }
1002        }
1003    }
1004
1005    /// Get the list of modified attribute names for an object.
1006    ///
1007    /// Returns the column names that have changed since the object was loaded.
1008    /// Returns an empty vector if:
1009    /// - Object is not tracked
1010    /// - Object is new (all fields are "modified")
1011    /// - Object is clean (no changes)
1012    ///
1013    /// # Example
1014    ///
1015    /// ```ignore
1016    /// let mut user = session.get::<User>(1).await?.unwrap();
1017    /// user.name = "New Name".to_string();
1018    /// session.mark_dirty(&user);
1019    ///
1020    /// let changed = session.modified_attributes(&user);
1021    /// assert!(changed.contains(&"name"));
1022    /// ```
1023    pub fn modified_attributes<M: Model + Serialize + 'static>(
1024        &self,
1025        obj: &M,
1026    ) -> Vec<&'static str> {
1027        let key = ObjectKey::from_model(obj);
1028
1029        let Some(tracked) = self.identity_map.get(&key) else {
1030            return Vec::new();
1031        };
1032
1033        // Only meaningful for persistent objects
1034        if tracked.state != ObjectState::Persistent {
1035            return Vec::new();
1036        }
1037
1038        // Need original state for comparison
1039        let Some(original_bytes) = &tracked.original_state else {
1040            return Vec::new();
1041        };
1042
1043        // Deserialize original values
1044        let Ok(original_values): Result<Vec<Value>, _> = serde_json::from_slice(original_bytes)
1045        else {
1046            return Vec::new();
1047        };
1048
1049        // Compare each column
1050        let mut modified = Vec::new();
1051        for (i, col) in tracked.column_names.iter().enumerate() {
1052            let current = tracked.values.get(i);
1053            let original = original_values.get(i);
1054
1055            if current != original {
1056                modified.push(*col);
1057            }
1058        }
1059
1060        modified
1061    }
1062
1063    /// Get the state of a tracked object.
1064    ///
1065    /// Returns `None` if the object is not tracked by this session.
1066    pub fn object_state<M: Model + 'static>(&self, obj: &M) -> Option<ObjectState> {
1067        let key = ObjectKey::from_model(obj);
1068        self.identity_map.get(&key).map(|t| t.state)
1069    }
1070
1071    // ========================================================================
1072    // Expiration
1073    // ========================================================================
1074
1075    /// Expire an object's cached attributes, forcing reload on next access.
1076    ///
1077    /// After calling this method, the next `get()` call for this object will reload
1078    /// from the database instead of returning the cached version.
1079    ///
1080    /// # Arguments
1081    ///
1082    /// * `obj` - The object to expire.
1083    /// * `attributes` - Optional list of attribute names to expire. If `None`, all
1084    ///   attributes are expired.
1085    ///
1086    /// # Example
1087    ///
1088    /// ```ignore
1089    /// // Expire all attributes
1090    /// session.expire(&user, None);
1091    ///
1092    /// // Expire specific attributes
1093    /// session.expire(&user, Some(&["name", "email"]));
1094    ///
1095    /// // Next get() will reload from database
1096    /// let refreshed = session.get::<User>(cx, user.id).await?;
1097    /// ```
1098    ///
1099    /// # Notes
1100    ///
1101    /// - Expiring an object does not discard pending changes. If the object has been
1102    ///   modified but not flushed, those changes remain pending.
1103    /// - Expiring a detached or new object has no effect.
1104    #[tracing::instrument(level = "debug", skip(self, obj), fields(table = M::TABLE_NAME))]
1105    pub fn expire<M: Model + 'static>(&mut self, obj: &M, attributes: Option<&[&str]>) {
1106        let key = ObjectKey::from_model(obj);
1107
1108        let Some(tracked) = self.identity_map.get_mut(&key) else {
1109            tracing::debug!("Object not tracked, nothing to expire");
1110            return;
1111        };
1112
1113        // Only expire persistent objects
1114        match tracked.state {
1115            ObjectState::New | ObjectState::Detached | ObjectState::Deleted => {
1116                tracing::debug!(state = ?tracked.state, "Cannot expire object in this state");
1117                return;
1118            }
1119            ObjectState::Persistent | ObjectState::Expired => {}
1120        }
1121
1122        match attributes {
1123            None => {
1124                // Expire all attributes
1125                tracked.state = ObjectState::Expired;
1126                tracked.expired_attributes = None;
1127                tracing::debug!("Expired all attributes");
1128            }
1129            Some(attrs) => {
1130                // Expire specific attributes
1131                let mut expired = tracked.expired_attributes.take().unwrap_or_default();
1132                for attr in attrs {
1133                    expired.insert((*attr).to_string());
1134                }
1135                tracked.expired_attributes = Some(expired);
1136
1137                // If any attributes are expired, mark the object as expired
1138                if tracked.state == ObjectState::Persistent {
1139                    tracked.state = ObjectState::Expired;
1140                }
1141                tracing::debug!(attributes = ?attrs, "Expired specific attributes");
1142            }
1143        }
1144    }
1145
1146    /// Expire all objects in the session.
1147    ///
1148    /// After calling this method, all tracked objects will be marked as expired.
1149    /// The next access to any object will reload from the database.
1150    ///
1151    /// # Example
1152    ///
1153    /// ```ignore
1154    /// // Expire everything in the session
1155    /// session.expire_all();
1156    ///
1157    /// // All subsequent get() calls will reload from database
1158    /// let user = session.get::<User>(cx, 1).await?;  // Reloads from DB
1159    /// let team = session.get::<Team>(cx, 1).await?;  // Reloads from DB
1160    /// ```
1161    ///
1162    /// # Notes
1163    ///
1164    /// - This does not affect new or deleted objects.
1165    /// - Pending changes are not discarded.
1166    #[tracing::instrument(level = "debug", skip(self))]
1167    pub fn expire_all(&mut self) {
1168        let mut expired_count = 0;
1169        for tracked in self.identity_map.values_mut() {
1170            if tracked.state == ObjectState::Persistent {
1171                tracked.state = ObjectState::Expired;
1172                tracked.expired_attributes = None;
1173                expired_count += 1;
1174            }
1175        }
1176        tracing::debug!(count = expired_count, "Expired all session objects");
1177    }
1178
1179    /// Check if an object is expired (needs reload from database).
1180    ///
1181    /// Returns `true` if the object is marked as expired and will be reloaded
1182    /// on the next access.
1183    pub fn is_expired<M: Model + 'static>(&self, obj: &M) -> bool {
1184        let key = ObjectKey::from_model(obj);
1185        self.identity_map
1186            .get(&key)
1187            .is_some_and(|t| t.state == ObjectState::Expired)
1188    }
1189
1190    /// Get the list of expired attribute names for an object.
1191    ///
1192    /// Returns:
1193    /// - `None` if the object is not tracked or not expired
1194    /// - `Some(None)` if all attributes are expired
1195    /// - `Some(Some(set))` if only specific attributes are expired
1196    pub fn expired_attributes<M: Model + 'static>(
1197        &self,
1198        obj: &M,
1199    ) -> Option<Option<&std::collections::HashSet<String>>> {
1200        let key = ObjectKey::from_model(obj);
1201        let tracked = self.identity_map.get(&key)?;
1202
1203        if tracked.state != ObjectState::Expired {
1204            return None;
1205        }
1206
1207        Some(tracked.expired_attributes.as_ref())
1208    }
1209
1210    /// Refresh an object by reloading it from the database.
1211    ///
1212    /// This method immediately reloads the object from the database, updating
1213    /// the cached copy in the session. Unlike `expire()`, which defers the reload
1214    /// until the next access, `refresh()` performs the reload immediately.
1215    ///
1216    /// # Arguments
1217    ///
1218    /// * `cx` - The async context for database operations.
1219    /// * `obj` - The object to refresh.
1220    ///
1221    /// # Returns
1222    ///
1223    /// Returns `Ok(Some(refreshed))` if the object was found in the database,
1224    /// `Ok(None)` if the object no longer exists in the database, or an error.
1225    ///
1226    /// # Example
1227    ///
1228    /// ```ignore
1229    /// // Immediately reload from database
1230    /// let refreshed = session.refresh(&cx, &user).await?;
1231    ///
1232    /// if let Some(user) = refreshed {
1233    ///     println!("Refreshed: {}", user.name);
1234    /// } else {
1235    ///     println!("User was deleted from database");
1236    /// }
1237    /// ```
1238    ///
1239    /// # Notes
1240    ///
1241    /// - This discards any changes in the session's cached copy.
1242    /// - If the object has pending changes, they will be lost.
1243    /// - If the object no longer exists in the database, it is removed from the session.
1244    #[tracing::instrument(level = "debug", skip(self, cx, obj), fields(table = M::TABLE_NAME))]
1245    pub async fn refresh<
1246        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1247    >(
1248        &mut self,
1249        cx: &Cx,
1250        obj: &M,
1251    ) -> Outcome<Option<M>, Error> {
1252        let pk_values = obj.primary_key_value();
1253        let key = ObjectKey::from_model(obj);
1254
1255        tracing::debug!(pk = ?pk_values, "Refreshing object from database");
1256
1257        // Remove from pending queues since we're reloading
1258        self.pending_dirty.retain(|k| k != &key);
1259
1260        // Remove from identity map to force reload
1261        self.identity_map.remove(&key);
1262
1263        // Reload from database
1264        let result = self.get_by_pk::<M>(cx, &pk_values).await;
1265
1266        match &result {
1267            Outcome::Ok(Some(_)) => {
1268                tracing::debug!("Object refreshed successfully");
1269            }
1270            Outcome::Ok(None) => {
1271                tracing::debug!("Object no longer exists in database");
1272            }
1273            _ => {}
1274        }
1275
1276        result
1277    }
1278
1279    // ========================================================================
1280    // Transaction Management
1281    // ========================================================================
1282
1283    /// Begin a transaction.
1284    pub async fn begin(&mut self, cx: &Cx) -> Outcome<(), Error> {
1285        if self.in_transaction {
1286            return Outcome::Ok(());
1287        }
1288
1289        match self.connection.execute(cx, "BEGIN", &[]).await {
1290            Outcome::Ok(_) => {
1291                self.in_transaction = true;
1292                Outcome::Ok(())
1293            }
1294            Outcome::Err(e) => Outcome::Err(e),
1295            Outcome::Cancelled(r) => Outcome::Cancelled(r),
1296            Outcome::Panicked(p) => Outcome::Panicked(p),
1297        }
1298    }
1299
    /// Flush pending changes to the database.
    ///
    /// This executes INSERT, UPDATE, and DELETE statements but does NOT commit.
    ///
    /// # Order of operations
    ///
    /// 1. Fire `SessionEvent::BeforeFlush`; a callback error aborts before any SQL runs.
    /// 2. If `config.auto_begin` is set and no transaction is open, call `begin()`.
    /// 3. DELETE every key queued in `pending_delete` whose entry is still `Deleted`.
    /// 4. INSERT every key queued in `pending_new` whose entry is not already `Persistent`.
    /// 5. UPDATE every key queued in `pending_dirty` whose entry is `Persistent` and
    ///    whose current values no longer match its `original_state` snapshot.
    /// 6. Fire `SessionEvent::AfterFlush`. A callback error is returned as
    ///    `Outcome::Err` even though the statements above have already run.
    ///
    /// Keys with no identity-map entry are skipped silently. Entries without
    /// primary-key columns or values are skipped (with a warning) for DELETE and
    /// UPDATE, because the statement would have no usable WHERE clause.
    ///
    /// # Failure handling
    ///
    /// If a statement returns `Err`, `Cancelled`, or `Panicked`, that outcome is
    /// returned immediately. The queue being processed is put back so a later
    /// flush can retry:
    /// - DELETE: keys whose DELETE already succeeded are removed from the identity
    ///   map and left out of the restored `pending_delete`.
    /// - INSERT / UPDATE: the whole queue is restored. Entries already written are
    ///   skipped on retry: inserted entries are `Persistent`, and updated entries
    ///   have a refreshed `original_state`.
    ///
    /// Statements that already executed are not undone by this method.
    pub async fn flush(&mut self, cx: &Cx) -> Outcome<(), Error> {
        // Fire before_flush event
        if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeFlush) {
            return Outcome::Err(e);
        }

        // Auto-begin transaction if configured
        if self.config.auto_begin && !self.in_transaction {
            match self.begin(cx).await {
                Outcome::Ok(()) => {}
                Outcome::Err(e) => return Outcome::Err(e),
                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                Outcome::Panicked(p) => return Outcome::Panicked(p),
            }
        }

        // 1. Execute DELETEs first (to respect FK constraints)
        // The queue is taken up front; the failure arms below put back whatever
        // has not been deleted yet.
        let deletes: Vec<ObjectKey> = std::mem::take(&mut self.pending_delete);
        // Keys whose DELETE succeeded. They are evicted from the identity map after
        // the loop, or before returning on failure.
        let mut actually_deleted: Vec<ObjectKey> = Vec::new();
        for key in &deletes {
            if let Some(tracked) = self.identity_map.get(key) {
                // Skip if object was un-deleted (state changed from Deleted)
                if tracked.state != ObjectState::Deleted {
                    continue;
                }

                // Skip objects without primary keys - cannot safely DELETE without WHERE clause
                if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
                    tracing::warn!(
                        table = tracked.table_name,
                        "Skipping DELETE for object without primary key - cannot identify row"
                    );
                    continue;
                }

                // Build WHERE clause from primary key columns and values, using
                // positional $1..$n placeholders bound to pk_values in order.
                // NOTE(review): identifiers are wrapped in double quotes but not
                // escaped; this assumes table/column names come from trusted model
                // metadata. Confirm.
                let where_parts: Vec<String> = tracked
                    .pk_columns
                    .iter()
                    .enumerate()
                    .map(|(i, col)| format!("\"{}\" = ${}", col, i + 1))
                    .collect();

                let sql = format!(
                    "DELETE FROM \"{}\" WHERE {}",
                    tracked.table_name,
                    where_parts.join(" AND ")
                );

                match self.connection.execute(cx, &sql, &tracked.pk_values).await {
                    Outcome::Ok(_) => {
                        actually_deleted.push(*key);
                    }
                    Outcome::Err(e) => {
                        // Only restore deletes that weren't already executed
                        // (exclude actually_deleted items from restoration)
                        self.pending_delete = deletes
                            .into_iter()
                            .filter(|k| !actually_deleted.contains(k))
                            .collect();
                        // Remove successfully deleted objects before returning error
                        for key in &actually_deleted {
                            self.identity_map.remove(key);
                        }
                        return Outcome::Err(e);
                    }
                    Outcome::Cancelled(r) => {
                        // Same handling for cancellation
                        self.pending_delete = deletes
                            .into_iter()
                            .filter(|k| !actually_deleted.contains(k))
                            .collect();
                        for key in &actually_deleted {
                            self.identity_map.remove(key);
                        }
                        return Outcome::Cancelled(r);
                    }
                    Outcome::Panicked(p) => {
                        // Same handling for panic
                        self.pending_delete = deletes
                            .into_iter()
                            .filter(|k| !actually_deleted.contains(k))
                            .collect();
                        for key in &actually_deleted {
                            self.identity_map.remove(key);
                        }
                        return Outcome::Panicked(p);
                    }
                }
            }
        }

        // Remove only actually deleted objects from identity map
        for key in &actually_deleted {
            self.identity_map.remove(key);
        }

        // 2. Execute INSERTs
        // The whole queue is restored on failure; entries inserted before the
        // failure are Persistent and skipped by the check below on retry.
        let inserts: Vec<ObjectKey> = std::mem::take(&mut self.pending_new);
        for key in &inserts {
            if let Some(tracked) = self.identity_map.get_mut(key) {
                // Skip if already persistent (was inserted in a previous attempt before error)
                if tracked.state == ObjectState::Persistent {
                    continue;
                }

                // Build INSERT statement using stored column names and values
                // (one $i placeholder per stored column, bound to tracked.values).
                let columns = &tracked.column_names;
                let placeholders: Vec<String> =
                    (1..=columns.len()).map(|i| format!("${}", i)).collect();

                let sql = format!(
                    "INSERT INTO \"{}\" ({}) VALUES ({})",
                    tracked.table_name,
                    columns
                        .iter()
                        .map(|c| format!("\"{}\"", c))
                        .collect::<Vec<_>>()
                        .join(", "),
                    placeholders.join(", ")
                );

                match self.connection.execute(cx, &sql, &tracked.values).await {
                    Outcome::Ok(_) => {
                        tracked.state = ObjectState::Persistent;
                        // Set original_state for future dirty checking (serialize current values)
                        tracked.original_state =
                            Some(serde_json::to_vec(&tracked.values).unwrap_or_default());
                    }
                    Outcome::Err(e) => {
                        // Restore pending_new for retry
                        self.pending_new = inserts;
                        return Outcome::Err(e);
                    }
                    Outcome::Cancelled(r) => {
                        // Restore pending_new for retry (same as Err handling)
                        self.pending_new = inserts;
                        return Outcome::Cancelled(r);
                    }
                    Outcome::Panicked(p) => {
                        // Restore pending_new for retry (same as Err handling)
                        self.pending_new = inserts;
                        return Outcome::Panicked(p);
                    }
                }
            }
        }

        // 3. Execute UPDATEs for dirty objects
        // The whole queue is restored on failure; entries already updated have a
        // refreshed original_state, so the dirty check below skips them on retry.
        let dirty: Vec<ObjectKey> = std::mem::take(&mut self.pending_dirty);
        for key in &dirty {
            if let Some(tracked) = self.identity_map.get_mut(key) {
                // Only UPDATE persistent objects
                if tracked.state != ObjectState::Persistent {
                    continue;
                }

                // Skip objects without primary keys - cannot safely UPDATE without WHERE clause
                if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
                    tracing::warn!(
                        table = tracked.table_name,
                        "Skipping UPDATE for object without primary key - cannot identify row"
                    );
                    continue;
                }

                // Check if actually dirty by comparing serialized state.
                // current_state is kept so it can become the new snapshot on success.
                let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
                let is_dirty = tracked.original_state.as_ref() != Some(&current_state);

                if !is_dirty {
                    continue;
                }

                // Build UPDATE statement with all non-PK columns.
                // SET parameters are numbered first ($1..$k). WHERE parameters
                // continue from $k+1. `params` is filled in the same order.
                let mut set_parts = Vec::new();
                let mut params = Vec::new();
                let mut param_idx = 1;

                for (i, col) in tracked.column_names.iter().enumerate() {
                    // Skip primary key columns in SET clause
                    if !tracked.pk_columns.contains(col) {
                        set_parts.push(format!("\"{}\" = ${}", col, param_idx));
                        params.push(tracked.values[i].clone());
                        param_idx += 1;
                    }
                }

                // Add WHERE clause for primary key
                let where_parts: Vec<String> = tracked
                    .pk_columns
                    .iter()
                    .map(|col| {
                        let clause = format!("\"{}\" = ${}", col, param_idx);
                        param_idx += 1;
                        clause
                    })
                    .collect();

                // Add PK values to params
                params.extend(tracked.pk_values.clone());

                if set_parts.is_empty() {
                    continue; // No non-PK columns to update (original_state is left as-is)
                }

                let sql = format!(
                    "UPDATE \"{}\" SET {} WHERE {}",
                    tracked.table_name,
                    set_parts.join(", "),
                    where_parts.join(" AND ")
                );

                match self.connection.execute(cx, &sql, &params).await {
                    Outcome::Ok(_) => {
                        // Update original_state to current state
                        tracked.original_state = Some(current_state);
                    }
                    Outcome::Err(e) => {
                        // Restore pending_dirty for retry
                        self.pending_dirty = dirty;
                        return Outcome::Err(e);
                    }
                    Outcome::Cancelled(r) => {
                        // Restore pending_dirty for retry (same as Err handling)
                        self.pending_dirty = dirty;
                        return Outcome::Cancelled(r);
                    }
                    Outcome::Panicked(p) => {
                        // Restore pending_dirty for retry (same as Err handling)
                        self.pending_dirty = dirty;
                        return Outcome::Panicked(p);
                    }
                }
            }
        }

        // Fire after_flush event
        if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterFlush) {
            return Outcome::Err(e);
        }

        Outcome::Ok(())
    }
1547
1548    /// Commit the current transaction.
1549    pub async fn commit(&mut self, cx: &Cx) -> Outcome<(), Error> {
1550        // Flush any pending changes first
1551        match self.flush(cx).await {
1552            Outcome::Ok(()) => {}
1553            Outcome::Err(e) => return Outcome::Err(e),
1554            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1555            Outcome::Panicked(p) => return Outcome::Panicked(p),
1556        }
1557
1558        // Fire before_commit event (can abort)
1559        if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeCommit) {
1560            return Outcome::Err(e);
1561        }
1562
1563        if self.in_transaction {
1564            match self.connection.execute(cx, "COMMIT", &[]).await {
1565                Outcome::Ok(_) => {
1566                    self.in_transaction = false;
1567                }
1568                Outcome::Err(e) => return Outcome::Err(e),
1569                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1570                Outcome::Panicked(p) => return Outcome::Panicked(p),
1571            }
1572        }
1573
1574        // Expire objects if configured
1575        if self.config.expire_on_commit {
1576            for tracked in self.identity_map.values_mut() {
1577                if tracked.state == ObjectState::Persistent {
1578                    tracked.state = ObjectState::Expired;
1579                }
1580            }
1581        }
1582
1583        // Fire after_commit event
1584        if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterCommit) {
1585            return Outcome::Err(e);
1586        }
1587
1588        Outcome::Ok(())
1589    }
1590
1591    /// Rollback the current transaction.
1592    pub async fn rollback(&mut self, cx: &Cx) -> Outcome<(), Error> {
1593        if self.in_transaction {
1594            match self.connection.execute(cx, "ROLLBACK", &[]).await {
1595                Outcome::Ok(_) => {
1596                    self.in_transaction = false;
1597                }
1598                Outcome::Err(e) => return Outcome::Err(e),
1599                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1600                Outcome::Panicked(p) => return Outcome::Panicked(p),
1601            }
1602        }
1603
1604        // Clear pending operations
1605        self.pending_new.clear();
1606        self.pending_delete.clear();
1607        self.pending_dirty.clear();
1608
1609        // Revert objects to original state or remove new ones
1610        let mut to_remove = Vec::new();
1611        for (key, tracked) in &mut self.identity_map {
1612            match tracked.state {
1613                ObjectState::New => {
1614                    to_remove.push(*key);
1615                }
1616                ObjectState::Deleted => {
1617                    tracked.state = ObjectState::Persistent;
1618                }
1619                _ => {}
1620            }
1621        }
1622
1623        for key in to_remove {
1624            self.identity_map.remove(&key);
1625        }
1626
1627        // Fire after_rollback event
1628        if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterRollback) {
1629            return Outcome::Err(e);
1630        }
1631
1632        Outcome::Ok(())
1633    }
1634
1635    // ========================================================================
1636    // Lazy Loading
1637    // ========================================================================
1638
1639    /// Load a single lazy relationship.
1640    ///
1641    /// Fetches the related object from the database and caches it in the Lazy wrapper.
1642    /// If the relationship has already been loaded, returns the cached value.
1643    ///
1644    /// # Example
1645    ///
1646    /// ```ignore
1647    /// session.load_lazy(&hero.team, &cx).await?;
1648    /// let team = hero.team.get(); // Now available
1649    /// ```
1650    #[tracing::instrument(level = "debug", skip(self, lazy, cx))]
1651    pub async fn load_lazy<
1652        T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1653    >(
1654        &mut self,
1655        lazy: &Lazy<T>,
1656        cx: &Cx,
1657    ) -> Outcome<bool, Error> {
1658        tracing::debug!(
1659            model = std::any::type_name::<T>(),
1660            fk = ?lazy.fk(),
1661            already_loaded = lazy.is_loaded(),
1662            "Loading lazy relationship"
1663        );
1664
1665        // If already loaded, return success
1666        if lazy.is_loaded() {
1667            tracing::trace!("Already loaded");
1668            return Outcome::Ok(lazy.get().is_some());
1669        }
1670
1671        // If no FK, set as empty and return
1672        let Some(fk) = lazy.fk() else {
1673            let _ = lazy.set_loaded(None);
1674            return Outcome::Ok(false);
1675        };
1676
1677        // Fetch from database using get()
1678        let obj = match self.get::<T>(cx, fk.clone()).await {
1679            Outcome::Ok(obj) => obj,
1680            Outcome::Err(e) => return Outcome::Err(e),
1681            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1682            Outcome::Panicked(p) => return Outcome::Panicked(p),
1683        };
1684
1685        let found = obj.is_some();
1686
1687        // Cache the result
1688        let _ = lazy.set_loaded(obj);
1689
1690        tracing::debug!(found = found, "Lazy load complete");
1691
1692        Outcome::Ok(found)
1693    }
1694
1695    /// Batch load lazy relationships for multiple objects.
1696    ///
1697    /// This method collects all FK values, executes a single query, and populates
1698    /// each Lazy field. This prevents the N+1 query problem when iterating over
1699    /// a collection and accessing lazy relationships.
1700    ///
1701    /// # Example
1702    ///
1703    /// ```ignore
1704    /// // Load 100 heroes
1705    /// let mut heroes = session.query::<Hero>().all().await?;
1706    ///
1707    /// // Without batch loading: 100 queries (N+1 problem)
1708    /// // With batch loading: 1 query
1709    /// session.load_many(&cx, &mut heroes, |h| &h.team).await?;
1710    ///
1711    /// // All teams now loaded
1712    /// for hero in &heroes {
1713    ///     if let Some(team) = hero.team.get() {
1714    ///         println!("{} is on {}", hero.name, team.name);
1715    ///     }
1716    /// }
1717    /// ```
1718    #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor))]
1719    pub async fn load_many<P, T, F>(
1720        &mut self,
1721        cx: &Cx,
1722        objects: &[P],
1723        accessor: F,
1724    ) -> Outcome<usize, Error>
1725    where
1726        P: Model + 'static,
1727        T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1728        F: Fn(&P) -> &Lazy<T>,
1729    {
1730        // Collect all FK values that need loading
1731        let mut fk_values: Vec<Value> = Vec::new();
1732        let mut fk_indices: Vec<usize> = Vec::new();
1733
1734        for (idx, obj) in objects.iter().enumerate() {
1735            let lazy = accessor(obj);
1736            if !lazy.is_loaded() && !lazy.is_empty() {
1737                if let Some(fk) = lazy.fk() {
1738                    fk_values.push(fk.clone());
1739                    fk_indices.push(idx);
1740                }
1741            }
1742        }
1743
1744        let fk_count = fk_values.len();
1745        tracing::info!(
1746            parent_model = std::any::type_name::<P>(),
1747            related_model = std::any::type_name::<T>(),
1748            parent_count = objects.len(),
1749            fk_count = fk_count,
1750            "Batch loading lazy relationships"
1751        );
1752
1753        if fk_values.is_empty() {
1754            // Nothing to load - mark all empty/loaded Lazy fields
1755            for obj in objects {
1756                let lazy = accessor(obj);
1757                if !lazy.is_loaded() && lazy.is_empty() {
1758                    let _ = lazy.set_loaded(None);
1759                }
1760            }
1761            return Outcome::Ok(0);
1762        }
1763
1764        // Build query with IN clause
1765        let pk_col = T::PRIMARY_KEY.first().unwrap_or(&"id");
1766        let placeholders: Vec<String> = (1..=fk_values.len()).map(|i| format!("${}", i)).collect();
1767        let sql = format!(
1768            "SELECT * FROM \"{}\" WHERE \"{}\" IN ({})",
1769            T::TABLE_NAME,
1770            pk_col,
1771            placeholders.join(", ")
1772        );
1773
1774        let rows = match self.connection.query(cx, &sql, &fk_values).await {
1775            Outcome::Ok(rows) => rows,
1776            Outcome::Err(e) => return Outcome::Err(e),
1777            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1778            Outcome::Panicked(p) => return Outcome::Panicked(p),
1779        };
1780
1781        // Convert rows to objects and build PK hash -> object lookup
1782        let mut lookup: HashMap<u64, T> = HashMap::new();
1783        for row in &rows {
1784            match T::from_row(row) {
1785                Ok(obj) => {
1786                    let pk_values = obj.primary_key_value();
1787                    let pk_hash = hash_values(&pk_values);
1788
1789                    // Add to session identity map
1790                    let key = ObjectKey::from_pk::<T>(&pk_values);
1791
1792                    // Extract column data from the model while we have the concrete type
1793                    let row_data = obj.to_row();
1794                    let column_names: Vec<&'static str> =
1795                        row_data.iter().map(|(name, _)| *name).collect();
1796                    let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
1797
1798                    // Serialize values for dirty checking (must match format used in flush)
1799                    let serialized = serde_json::to_vec(&values).ok();
1800
1801                    let tracked = TrackedObject {
1802                        object: Box::new(obj.clone()),
1803                        original_state: serialized,
1804                        state: ObjectState::Persistent,
1805                        table_name: T::TABLE_NAME,
1806                        column_names,
1807                        values,
1808                        pk_columns: T::PRIMARY_KEY.to_vec(),
1809                        pk_values: pk_values.clone(),
1810                        expired_attributes: None,
1811                    };
1812                    self.identity_map.insert(key, tracked);
1813
1814                    // Add to lookup
1815                    lookup.insert(pk_hash, obj);
1816                }
1817                Err(_) => continue,
1818            }
1819        }
1820
1821        // Populate each Lazy field
1822        let mut loaded_count = 0;
1823        for obj in objects {
1824            let lazy = accessor(obj);
1825            if !lazy.is_loaded() {
1826                if let Some(fk) = lazy.fk() {
1827                    let fk_hash = hash_values(std::slice::from_ref(fk));
1828                    let related = lookup.get(&fk_hash).cloned();
1829                    let found = related.is_some();
1830                    let _ = lazy.set_loaded(related);
1831                    if found {
1832                        loaded_count += 1;
1833                    }
1834                } else {
1835                    let _ = lazy.set_loaded(None);
1836                }
1837            }
1838        }
1839
1840        tracing::debug!(
1841            query_count = 1,
1842            loaded_count = loaded_count,
1843            "Batch load complete"
1844        );
1845
1846        Outcome::Ok(loaded_count)
1847    }
1848
1849    /// Batch load many-to-many relationships for multiple parent objects.
1850    ///
1851    /// This method loads related objects via a link table in a single query,
1852    /// avoiding the N+1 problem for many-to-many relationships.
1853    ///
1854    /// # Example
1855    ///
1856    /// ```ignore
1857    /// // Load 100 heroes
1858    /// let mut heroes = session.query::<Hero>().all().await?;
1859    ///
1860    /// // Without batch loading: 100 queries (N+1 problem)
1861    /// // With batch loading: 1 query via JOIN
1862    /// let link_info = LinkTableInfo::new("hero_powers", "hero_id", "power_id");
1863    /// session.load_many_to_many(&cx, &mut heroes, |h| &mut h.powers, |h| h.id.unwrap(), &link_info).await?;
1864    ///
1865    /// // All powers now loaded
1866    /// for hero in &heroes {
1867    ///     if let Some(powers) = hero.powers.get() {
1868    ///         println!("{} has {} powers", hero.name, powers.len());
1869    ///     }
1870    /// }
1871    /// ```
1872    #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
1873    pub async fn load_many_to_many<P, Child, FA, FP>(
1874        &mut self,
1875        cx: &Cx,
1876        objects: &mut [P],
1877        accessor: FA,
1878        parent_pk: FP,
1879        link_table: &sqlmodel_core::LinkTableInfo,
1880    ) -> Outcome<usize, Error>
1881    where
1882        P: Model + 'static,
1883        Child: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1884        FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
1885        FP: Fn(&P) -> Value,
1886    {
1887        // Collect all parent PK values
1888        let pks: Vec<Value> = objects.iter().map(&parent_pk).collect();
1889
1890        tracing::info!(
1891            parent_model = std::any::type_name::<P>(),
1892            related_model = std::any::type_name::<Child>(),
1893            parent_count = pks.len(),
1894            link_table = link_table.table_name,
1895            "Batch loading many-to-many relationships"
1896        );
1897
1898        if pks.is_empty() {
1899            return Outcome::Ok(0);
1900        }
1901
1902        // Build query with JOIN through link table:
1903        // SELECT child.*, link.local_column as __parent_pk
1904        // FROM child
1905        // JOIN link ON child.pk = link.remote_column
1906        // WHERE link.local_column IN (...)
1907        let child_pk_col = Child::PRIMARY_KEY.first().unwrap_or(&"id");
1908        let placeholders: Vec<String> = (1..=pks.len()).map(|i| format!("${}", i)).collect();
1909        let sql = format!(
1910            "SELECT \"{}\".*, \"{}\".\"{}\" AS __parent_pk FROM \"{}\" \
1911             JOIN \"{}\" ON \"{}\".\"{}\" = \"{}\".\"{}\" \
1912             WHERE \"{}\".\"{}\" IN ({})",
1913            Child::TABLE_NAME,
1914            link_table.table_name,
1915            link_table.local_column,
1916            Child::TABLE_NAME,
1917            link_table.table_name,
1918            Child::TABLE_NAME,
1919            child_pk_col,
1920            link_table.table_name,
1921            link_table.remote_column,
1922            link_table.table_name,
1923            link_table.local_column,
1924            placeholders.join(", ")
1925        );
1926
1927        tracing::trace!(sql = %sql, "Many-to-many batch SQL");
1928
1929        let rows = match self.connection.query(cx, &sql, &pks).await {
1930            Outcome::Ok(rows) => rows,
1931            Outcome::Err(e) => return Outcome::Err(e),
1932            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1933            Outcome::Panicked(p) => return Outcome::Panicked(p),
1934        };
1935
1936        // Group children by parent PK
1937        let mut by_parent: HashMap<u64, Vec<Child>> = HashMap::new();
1938        for row in &rows {
1939            // Extract the parent PK from the __parent_pk alias
1940            let parent_pk_value: Value = match row.get_by_name("__parent_pk") {
1941                Some(v) => v.clone(),
1942                None => continue,
1943            };
1944            let parent_pk_hash = hash_values(std::slice::from_ref(&parent_pk_value));
1945
1946            // Parse the child model
1947            match Child::from_row(row) {
1948                Ok(child) => {
1949                    by_parent.entry(parent_pk_hash).or_default().push(child);
1950                }
1951                Err(_) => continue,
1952            }
1953        }
1954
1955        // Populate each RelatedMany field
1956        let mut loaded_count = 0;
1957        for obj in objects {
1958            let pk = parent_pk(obj);
1959            let pk_hash = hash_values(std::slice::from_ref(&pk));
1960            let children = by_parent.remove(&pk_hash).unwrap_or_default();
1961            let child_count = children.len();
1962
1963            let related = accessor(obj);
1964            related.set_parent_pk(pk);
1965            let _ = related.set_loaded(children);
1966            loaded_count += child_count;
1967        }
1968
1969        tracing::debug!(
1970            query_count = 1,
1971            total_children = loaded_count,
1972            "Many-to-many batch load complete"
1973        );
1974
1975        Outcome::Ok(loaded_count)
1976    }
1977
1978    /// Flush pending link/unlink operations for many-to-many relationships.
1979    ///
1980    /// This method persists pending link and unlink operations that were tracked
1981    /// via `RelatedMany::link()` and `RelatedMany::unlink()` calls.
1982    ///
1983    /// # Example
1984    ///
1985    /// ```ignore
1986    /// // Add a power to a hero
1987    /// hero.powers.link(&fly_power);
1988    ///
1989    /// // Remove a power from a hero
1990    /// hero.powers.unlink(&x_ray_vision);
1991    ///
1992    /// // Flush the link table operations
1993    /// let link_info = LinkTableInfo::new("hero_powers", "hero_id", "power_id");
1994    /// session.flush_related_many(&cx, &mut [hero], |h| &mut h.powers, |h| h.id.unwrap(), &link_info).await?;
1995    /// ```
1996    #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
1997    pub async fn flush_related_many<P, Child, FA, FP>(
1998        &mut self,
1999        cx: &Cx,
2000        objects: &mut [P],
2001        accessor: FA,
2002        parent_pk: FP,
2003        link_table: &sqlmodel_core::LinkTableInfo,
2004    ) -> Outcome<usize, Error>
2005    where
2006        P: Model + 'static,
2007        Child: Model + 'static,
2008        FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2009        FP: Fn(&P) -> Value,
2010    {
2011        let mut ops = Vec::new();
2012
2013        // Collect pending operations from all objects
2014        for obj in objects.iter_mut() {
2015            let parent_pk_value = parent_pk(obj);
2016            let related = accessor(obj);
2017
2018            // Collect pending links
2019            for child_pk_values in related.take_pending_links() {
2020                if let Some(child_pk) = child_pk_values.first() {
2021                    ops.push(LinkTableOp::link(
2022                        link_table.table_name.to_string(),
2023                        link_table.local_column.to_string(),
2024                        parent_pk_value.clone(),
2025                        link_table.remote_column.to_string(),
2026                        child_pk.clone(),
2027                    ));
2028                }
2029            }
2030
2031            // Collect pending unlinks
2032            for child_pk_values in related.take_pending_unlinks() {
2033                if let Some(child_pk) = child_pk_values.first() {
2034                    ops.push(LinkTableOp::unlink(
2035                        link_table.table_name.to_string(),
2036                        link_table.local_column.to_string(),
2037                        parent_pk_value.clone(),
2038                        link_table.remote_column.to_string(),
2039                        child_pk.clone(),
2040                    ));
2041                }
2042            }
2043        }
2044
2045        if ops.is_empty() {
2046            return Outcome::Ok(0);
2047        }
2048
2049        tracing::info!(
2050            parent_model = std::any::type_name::<P>(),
2051            related_model = std::any::type_name::<Child>(),
2052            link_count = ops
2053                .iter()
2054                .filter(|o| matches!(o, LinkTableOp::Link { .. }))
2055                .count(),
2056            unlink_count = ops
2057                .iter()
2058                .filter(|o| matches!(o, LinkTableOp::Unlink { .. }))
2059                .count(),
2060            link_table = link_table.table_name,
2061            "Flushing many-to-many relationship changes"
2062        );
2063
2064        flush::execute_link_table_ops(cx, &self.connection, &ops).await
2065    }
2066
2067    // ========================================================================
2068    // Bidirectional Relationship Sync (back_populates)
2069    // ========================================================================
2070
2071    /// Relate a child to a parent with bidirectional sync.
2072    ///
2073    /// Sets the parent on the child (ManyToOne side) and adds the child to the
2074    /// parent's collection (OneToMany side) if `back_populates` is defined.
2075    ///
2076    /// # Example
2077    ///
2078    /// ```ignore
2079    /// // Hero has a ManyToOne relationship to Team (hero.team)
2080    /// // Team has a OneToMany relationship to Hero (team.heroes) with back_populates
2081    ///
2082    /// session.relate_to_one(
2083    ///     &mut hero,
2084    ///     |h| &mut h.team,
2085    ///     |h| h.team_id = team.id,  // Set FK
2086    ///     &mut team,
2087    ///     |t| &mut t.heroes,
2088    /// );
2089    /// // Now hero.team is set AND team.heroes includes hero
2090    /// ```
2091    pub fn relate_to_one<Child, Parent, FC, FP, FK>(
2092        &self,
2093        child: &mut Child,
2094        child_accessor: FC,
2095        set_fk: FK,
2096        parent: &mut Parent,
2097        parent_accessor: FP,
2098    ) where
2099        Child: Model + Clone + 'static,
2100        Parent: Model + Clone + 'static,
2101        FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2102        FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2103        FK: FnOnce(&mut Child),
2104    {
2105        // Set the forward direction: child.parent = Related::loaded(parent)
2106        let related = child_accessor(child);
2107        let _ = related.set_loaded(Some(parent.clone()));
2108
2109        // Set the FK value
2110        set_fk(child);
2111
2112        // Set the reverse direction: parent.children.link(child)
2113        let related_many = parent_accessor(parent);
2114        related_many.link(child);
2115
2116        tracing::debug!(
2117            child_model = std::any::type_name::<Child>(),
2118            parent_model = std::any::type_name::<Parent>(),
2119            "Established bidirectional ManyToOne <-> OneToMany relationship"
2120        );
2121    }
2122
2123    /// Unrelate a child from a parent with bidirectional sync.
2124    ///
2125    /// Clears the parent on the child and removes the child from the parent's collection.
2126    ///
2127    /// # Example
2128    ///
2129    /// ```ignore
2130    /// session.unrelate_from_one(
2131    ///     &mut hero,
2132    ///     |h| &mut h.team,
2133    ///     |h| h.team_id = None,  // Clear FK
2134    ///     &mut team,
2135    ///     |t| &mut t.heroes,
2136    /// );
2137    /// ```
2138    pub fn unrelate_from_one<Child, Parent, FC, FP, FK>(
2139        &self,
2140        child: &mut Child,
2141        child_accessor: FC,
2142        clear_fk: FK,
2143        parent: &mut Parent,
2144        parent_accessor: FP,
2145    ) where
2146        Child: Model + Clone + 'static,
2147        Parent: Model + Clone + 'static,
2148        FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2149        FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2150        FK: FnOnce(&mut Child),
2151    {
2152        // Clear the forward direction by assigning an empty Related
2153        let related = child_accessor(child);
2154        *related = sqlmodel_core::Related::empty();
2155
2156        // Clear the FK value
2157        clear_fk(child);
2158
2159        // Remove from the reverse direction
2160        let related_many = parent_accessor(parent);
2161        related_many.unlink(child);
2162
2163        tracing::debug!(
2164            child_model = std::any::type_name::<Child>(),
2165            parent_model = std::any::type_name::<Parent>(),
2166            "Removed bidirectional ManyToOne <-> OneToMany relationship"
2167        );
2168    }
2169
2170    /// Relate two objects in a many-to-many relationship with bidirectional sync.
2171    ///
2172    /// Adds each object to the other's collection.
2173    ///
2174    /// # Example
2175    ///
2176    /// ```ignore
2177    /// // Hero has ManyToMany to Power via hero_powers link table
2178    /// // Power has ManyToMany to Hero via hero_powers link table (back_populates)
2179    ///
2180    /// session.relate_many_to_many(
2181    ///     &mut hero,
2182    ///     |h| &mut h.powers,
2183    ///     &mut power,
2184    ///     |p| &mut p.heroes,
2185    /// );
2186    /// // Now hero.powers includes power AND power.heroes includes hero
2187    /// ```
2188    pub fn relate_many_to_many<Left, Right, FL, FR>(
2189        &self,
2190        left: &mut Left,
2191        left_accessor: FL,
2192        right: &mut Right,
2193        right_accessor: FR,
2194    ) where
2195        Left: Model + Clone + 'static,
2196        Right: Model + Clone + 'static,
2197        FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2198        FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2199    {
2200        // Add right to left's collection
2201        let left_coll = left_accessor(left);
2202        left_coll.link(right);
2203
2204        // Add left to right's collection (back_populates)
2205        let right_coll = right_accessor(right);
2206        right_coll.link(left);
2207
2208        tracing::debug!(
2209            left_model = std::any::type_name::<Left>(),
2210            right_model = std::any::type_name::<Right>(),
2211            "Established bidirectional ManyToMany relationship"
2212        );
2213    }
2214
2215    /// Unrelate two objects in a many-to-many relationship with bidirectional sync.
2216    ///
2217    /// Removes each object from the other's collection.
2218    pub fn unrelate_many_to_many<Left, Right, FL, FR>(
2219        &self,
2220        left: &mut Left,
2221        left_accessor: FL,
2222        right: &mut Right,
2223        right_accessor: FR,
2224    ) where
2225        Left: Model + Clone + 'static,
2226        Right: Model + Clone + 'static,
2227        FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2228        FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2229    {
2230        // Remove right from left's collection
2231        let left_coll = left_accessor(left);
2232        left_coll.unlink(right);
2233
2234        // Remove left from right's collection (back_populates)
2235        let right_coll = right_accessor(right);
2236        right_coll.unlink(left);
2237
2238        tracing::debug!(
2239            left_model = std::any::type_name::<Left>(),
2240            right_model = std::any::type_name::<Right>(),
2241            "Removed bidirectional ManyToMany relationship"
2242        );
2243    }
2244
2245    // ========================================================================
2246    // N+1 Query Detection
2247    // ========================================================================
2248
2249    /// Enable N+1 query detection with the specified threshold.
2250    ///
2251    /// When the number of lazy loads for a single relationship reaches the
2252    /// threshold, a warning is emitted suggesting batch loading.
2253    ///
2254    /// # Example
2255    ///
2256    /// ```ignore
2257    /// session.enable_n1_detection(3);  // Warn after 3 lazy loads
2258    ///
2259    /// // This will trigger a warning:
2260    /// for hero in &mut heroes {
2261    ///     hero.team.load(&mut session).await?;
2262    /// }
2263    ///
2264    /// // Check stats
2265    /// if let Some(stats) = session.n1_stats() {
2266    ///     println!("Potential N+1 issues: {}", stats.potential_n1);
2267    /// }
2268    /// ```
2269    pub fn enable_n1_detection(&mut self, threshold: usize) {
2270        self.n1_tracker = Some(N1QueryTracker::new().with_threshold(threshold));
2271    }
2272
2273    /// Disable N+1 query detection and clear the tracker.
2274    pub fn disable_n1_detection(&mut self) {
2275        self.n1_tracker = None;
2276    }
2277
2278    /// Check if N+1 detection is enabled.
2279    #[must_use]
2280    pub fn n1_detection_enabled(&self) -> bool {
2281        self.n1_tracker.is_some()
2282    }
2283
2284    /// Get mutable access to the N+1 tracker (for recording loads).
2285    pub fn n1_tracker_mut(&mut self) -> Option<&mut N1QueryTracker> {
2286        self.n1_tracker.as_mut()
2287    }
2288
2289    /// Get N+1 detection statistics.
2290    #[must_use]
2291    pub fn n1_stats(&self) -> Option<N1Stats> {
2292        self.n1_tracker.as_ref().map(|t| t.stats())
2293    }
2294
2295    /// Reset N+1 detection counts (call at start of new request/transaction).
2296    pub fn reset_n1_tracking(&mut self) {
2297        if let Some(tracker) = &mut self.n1_tracker {
2298            tracker.reset();
2299        }
2300    }
2301
2302    /// Record a lazy load for N+1 detection.
2303    ///
2304    /// This is called automatically by lazy loading methods.
2305    #[track_caller]
2306    pub fn record_lazy_load(&mut self, parent_type: &'static str, relationship: &'static str) {
2307        if let Some(tracker) = &mut self.n1_tracker {
2308            tracker.record_load(parent_type, relationship);
2309        }
2310    }
2311
2312    // ========================================================================
2313    // Merge (Detached Object Reattachment)
2314    // ========================================================================
2315
2316    /// Merge a detached object back into the session.
2317    ///
2318    /// This method reattaches a detached or externally-created object to the session,
2319    /// copying its state to the session-tracked instance if one exists.
2320    ///
2321    /// # Behavior
2322    ///
2323    /// 1. **If object with same PK exists in session**: Updates the tracked object
2324    ///    with values from the provided object and returns a clone of the tracked version.
2325    ///
2326    /// 2. **If `load` is true and object not in session**: Queries the database for
2327    ///    an existing row, merges the provided values onto it, and tracks it.
2328    ///
2329    /// 3. **If object not in session or DB**: Treats it as new (will INSERT on flush).
2330    ///
2331    /// # Example
2332    ///
2333    /// ```ignore
2334    /// // Object from previous session or external source
2335    /// let mut detached_user = User { id: Some(1), name: "Updated Name".into(), .. };
2336    ///
2337    /// // Merge into current session
2338    /// let attached_user = session.merge(&cx, detached_user, true).await?;
2339    ///
2340    /// // attached_user is now tracked, changes will be persisted on flush
2341    /// session.flush(&cx).await?;
2342    /// ```
2343    ///
2344    /// # Parameters
2345    ///
2346    /// - `cx`: The async context for database operations.
2347    /// - `model`: The detached model instance to merge.
2348    /// - `load`: If true, load from database when not in identity map.
2349    ///
2350    /// # Returns
2351    ///
2352    /// The session-attached version of the object. If the object was already tracked,
2353    /// returns a clone of the updated tracked object. Otherwise, returns a clone of
2354    /// the newly tracked object.
2355    #[tracing::instrument(level = "debug", skip(self, cx, model), fields(table = M::TABLE_NAME))]
2356    pub async fn merge<
2357        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2358    >(
2359        &mut self,
2360        cx: &Cx,
2361        model: M,
2362        load: bool,
2363    ) -> Outcome<M, Error> {
2364        let pk_values = model.primary_key_value();
2365        let key = ObjectKey::from_model(&model);
2366
2367        tracing::debug!(
2368            pk = ?pk_values,
2369            load = load,
2370            in_identity_map = self.identity_map.contains_key(&key),
2371            "Merging object"
2372        );
2373
2374        // 1. Check identity map first
2375        if let Some(tracked) = self.identity_map.get_mut(&key) {
2376            // Skip if detached - we shouldn't merge into detached objects
2377            if tracked.state == ObjectState::Detached {
2378                tracing::debug!("Found detached object, treating as new");
2379            } else {
2380                tracing::debug!(
2381                    state = ?tracked.state,
2382                    "Found tracked object, updating with merged values"
2383                );
2384
2385                // Update the tracked object with values from the provided model
2386                let row_data = model.to_row();
2387                tracked.object = Box::new(model.clone());
2388                tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
2389                tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
2390                tracked.pk_values.clone_from(&pk_values);
2391
2392                // If persistent, mark as dirty for UPDATE
2393                if tracked.state == ObjectState::Persistent && !self.pending_dirty.contains(&key) {
2394                    self.pending_dirty.push(key);
2395                }
2396
2397                // Return clone of the tracked object
2398                if let Some(obj) = tracked.object.downcast_ref::<M>() {
2399                    return Outcome::Ok(obj.clone());
2400                }
2401            }
2402        }
2403
2404        // 2. If load=true, try to fetch from database
2405        if load {
2406            // Check if we have a valid primary key (not null/default)
2407            let has_valid_pk = pk_values
2408                .iter()
2409                .all(|v| !matches!(v, Value::Null | Value::Default));
2410
2411            if has_valid_pk {
2412                tracing::debug!("Loading from database");
2413
2414                let db_result = self.get_by_pk::<M>(cx, &pk_values).await;
2415                match db_result {
2416                    Outcome::Ok(Some(_existing)) => {
2417                        // Now update the tracked object (which was added by get_by_pk)
2418                        // with the values from our model
2419                        if let Some(tracked) = self.identity_map.get_mut(&key) {
2420                            let row_data = model.to_row();
2421                            tracked.object = Box::new(model.clone());
2422                            tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
2423                            tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
2424                            // pk_values stay the same from DB
2425
2426                            // Mark as dirty since we're updating with new values
2427                            if !self.pending_dirty.contains(&key) {
2428                                self.pending_dirty.push(key);
2429                            }
2430
2431                            tracing::debug!("Merged values onto DB object");
2432
2433                            if let Some(obj) = tracked.object.downcast_ref::<M>() {
2434                                return Outcome::Ok(obj.clone());
2435                            }
2436                        }
2437                    }
2438                    Outcome::Ok(None) => {
2439                        tracing::debug!("Object not found in database, treating as new");
2440                    }
2441                    Outcome::Err(e) => return Outcome::Err(e),
2442                    Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2443                    Outcome::Panicked(p) => return Outcome::Panicked(p),
2444                }
2445            }
2446        }
2447
2448        // 3. Treat as new - add to session
2449        tracing::debug!("Adding as new object");
2450        self.add(&model);
2451
2452        Outcome::Ok(model)
2453    }
2454
2455    /// Merge a detached object without loading from database.
2456    ///
2457    /// This is a convenience method equivalent to `merge(cx, model, false)`.
2458    /// Use this when you know the object doesn't exist in the database or
2459    /// you don't want to query the database.
2460    ///
2461    /// # Example
2462    ///
2463    /// ```ignore
2464    /// let attached = session.merge_without_load(&cx, detached_user).await?;
2465    /// ```
2466    pub async fn merge_without_load<
2467        M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2468    >(
2469        &mut self,
2470        cx: &Cx,
2471        model: M,
2472    ) -> Outcome<M, Error> {
2473        self.merge(cx, model, false).await
2474    }
2475
2476    // ========================================================================
2477    // Debug Diagnostics
2478    // ========================================================================
2479
2480    /// Get count of objects pending INSERT.
2481    pub fn pending_new_count(&self) -> usize {
2482        self.pending_new.len()
2483    }
2484
2485    /// Get count of objects pending DELETE.
2486    pub fn pending_delete_count(&self) -> usize {
2487        self.pending_delete.len()
2488    }
2489
2490    /// Get count of dirty objects pending UPDATE.
2491    pub fn pending_dirty_count(&self) -> usize {
2492        self.pending_dirty.len()
2493    }
2494
2495    /// Get total tracked object count.
2496    pub fn tracked_count(&self) -> usize {
2497        self.identity_map.len()
2498    }
2499
2500    /// Whether we're in a transaction.
2501    pub fn in_transaction(&self) -> bool {
2502        self.in_transaction
2503    }
2504
2505    /// Dump session state for debugging.
2506    pub fn debug_state(&self) -> SessionDebugInfo {
2507        SessionDebugInfo {
2508            tracked: self.tracked_count(),
2509            pending_new: self.pending_new_count(),
2510            pending_delete: self.pending_delete_count(),
2511            pending_dirty: self.pending_dirty_count(),
2512            in_transaction: self.in_transaction,
2513        }
2514    }
2515
2516    // ========================================================================
2517    // Bulk Operations
2518    // ========================================================================
2519
2520    /// Bulk insert multiple model instances without object tracking.
2521    ///
2522    /// This generates a single multi-row INSERT statement and bypasses
2523    /// the identity map entirely, making it much faster for large batches.
2524    ///
2525    /// Models are inserted in chunks of `batch_size` to avoid excessively
2526    /// large SQL statements. The default batch size is 1000.
2527    ///
2528    /// Returns the total number of rows inserted.
2529    pub async fn bulk_insert<M: Model + Clone + Send + Sync + 'static>(
2530        &mut self,
2531        cx: &Cx,
2532        models: &[M],
2533    ) -> Outcome<u64, Error> {
2534        self.bulk_insert_with_batch_size(cx, models, 1000).await
2535    }
2536
2537    /// Bulk insert with a custom batch size.
2538    pub async fn bulk_insert_with_batch_size<M: Model + Clone + Send + Sync + 'static>(
2539        &mut self,
2540        cx: &Cx,
2541        models: &[M],
2542        batch_size: usize,
2543    ) -> Outcome<u64, Error> {
2544        if models.is_empty() {
2545            return Outcome::Ok(0);
2546        }
2547
2548        let batch_size = batch_size.max(1);
2549        let mut total_inserted: u64 = 0;
2550
2551        for chunk in models.chunks(batch_size) {
2552            let builder = sqlmodel_query::InsertManyBuilder::new(chunk);
2553            let (sql, params) = builder.build();
2554
2555            if sql.is_empty() {
2556                continue;
2557            }
2558
2559            match self.connection.execute(cx, &sql, &params).await {
2560                Outcome::Ok(count) => total_inserted += count,
2561                Outcome::Err(e) => return Outcome::Err(e),
2562                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2563                Outcome::Panicked(p) => return Outcome::Panicked(p),
2564            }
2565        }
2566
2567        Outcome::Ok(total_inserted)
2568    }
2569
2570    /// Bulk update multiple model instances without individual tracking.
2571    ///
2572    /// Each model is updated individually using its primary key, but
2573    /// all updates are executed in a single transaction without going
2574    /// through the identity map or change tracking.
2575    ///
2576    /// Returns the total number of rows updated.
2577    pub async fn bulk_update<M: Model + Clone + Send + Sync + 'static>(
2578        &mut self,
2579        cx: &Cx,
2580        models: &[M],
2581    ) -> Outcome<u64, Error> {
2582        if models.is_empty() {
2583            return Outcome::Ok(0);
2584        }
2585
2586        let mut total_updated: u64 = 0;
2587
2588        for model in models {
2589            let builder = sqlmodel_query::UpdateBuilder::new(model);
2590            let (sql, params) = builder.build();
2591
2592            if sql.is_empty() {
2593                continue;
2594            }
2595
2596            match self.connection.execute(cx, &sql, &params).await {
2597                Outcome::Ok(count) => total_updated += count,
2598                Outcome::Err(e) => return Outcome::Err(e),
2599                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2600                Outcome::Panicked(p) => return Outcome::Panicked(p),
2601            }
2602        }
2603
2604        Outcome::Ok(total_updated)
2605    }
2606}
2607
2608impl<C, M> LazyLoader<M> for Session<C>
2609where
2610    C: Connection,
2611    M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2612{
2613    fn get(
2614        &mut self,
2615        cx: &Cx,
2616        pk: Value,
2617    ) -> impl Future<Output = Outcome<Option<M>, Error>> + Send {
2618        Session::get(self, cx, pk)
2619    }
2620}
2621
/// Debug information about session state.
///
/// A point-in-time snapshot of the session's bookkeeping counters, as built by
/// `Session::debug_state`. It derives `PartialEq`/`Eq` so snapshots can be
/// compared directly (e.g. in assertions). It derives `Default`, which yields an
/// empty, non-transactional snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionDebugInfo {
    /// Total tracked objects.
    pub tracked: usize,
    /// Objects pending INSERT.
    pub pending_new: usize,
    /// Objects pending DELETE.
    pub pending_delete: usize,
    /// Objects pending UPDATE.
    pub pending_dirty: usize,
    /// Whether in a transaction.
    pub in_transaction: bool,
}
2636
2637// ============================================================================
2638// Unit Tests
2639// ============================================================================
2640
2641#[cfg(test)]
2642#[allow(clippy::manual_async_fn)] // Mock trait impls must match trait signatures
2643mod tests {
2644    use super::*;
2645    use asupersync::runtime::RuntimeBuilder;
2646    use sqlmodel_core::Row;
2647    use std::sync::{Arc, Mutex};
2648
2649    #[test]
2650    fn test_session_config_defaults() {
2651        let config = SessionConfig::default();
2652        assert!(config.auto_begin);
2653        assert!(!config.auto_flush);
2654        assert!(config.expire_on_commit);
2655    }
2656
2657    #[test]
2658    fn test_object_key_hash_consistency() {
2659        let values1 = vec![Value::BigInt(42)];
2660        let values2 = vec![Value::BigInt(42)];
2661        let hash1 = hash_values(&values1);
2662        let hash2 = hash_values(&values2);
2663        assert_eq!(hash1, hash2);
2664    }
2665
2666    #[test]
2667    fn test_object_key_hash_different_values() {
2668        let values1 = vec![Value::BigInt(42)];
2669        let values2 = vec![Value::BigInt(43)];
2670        let hash1 = hash_values(&values1);
2671        let hash2 = hash_values(&values2);
2672        assert_ne!(hash1, hash2);
2673    }
2674
2675    #[test]
2676    fn test_object_key_hash_different_types() {
2677        let values1 = vec![Value::BigInt(42)];
2678        let values2 = vec![Value::Text("42".to_string())];
2679        let hash1 = hash_values(&values1);
2680        let hash2 = hash_values(&values2);
2681        assert_ne!(hash1, hash2);
2682    }
2683
2684    #[test]
2685    fn test_session_debug_info() {
2686        let info = SessionDebugInfo {
2687            tracked: 5,
2688            pending_new: 2,
2689            pending_delete: 1,
2690            pending_dirty: 0,
2691            in_transaction: true,
2692        };
2693        assert_eq!(info.tracked, 5);
2694        assert_eq!(info.pending_new, 2);
2695        assert!(info.in_transaction);
2696    }
2697
2698    fn unwrap_outcome<T: std::fmt::Debug>(outcome: Outcome<T, Error>) -> T {
2699        match outcome {
2700            Outcome::Ok(v) => v,
2701            other => std::panic::panic_any(format!("unexpected outcome: {other:?}")),
2702        }
2703    }
2704
2705    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2706    struct Team {
2707        id: Option<i64>,
2708        name: String,
2709    }
2710
2711    impl Model for Team {
2712        const TABLE_NAME: &'static str = "teams";
2713        const PRIMARY_KEY: &'static [&'static str] = &["id"];
2714
2715        fn fields() -> &'static [sqlmodel_core::FieldInfo] {
2716            &[]
2717        }
2718
2719        fn to_row(&self) -> Vec<(&'static str, Value)> {
2720            vec![
2721                ("id", self.id.map_or(Value::Null, Value::BigInt)),
2722                ("name", Value::Text(self.name.clone())),
2723            ]
2724        }
2725
2726        fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
2727            let id: i64 = row.get_named("id")?;
2728            let name: String = row.get_named("name")?;
2729            Ok(Self { id: Some(id), name })
2730        }
2731
2732        fn primary_key_value(&self) -> Vec<Value> {
2733            self.id
2734                .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
2735        }
2736
2737        fn is_new(&self) -> bool {
2738            self.id.is_none()
2739        }
2740    }
2741
2742    #[derive(Debug, Clone, Serialize, Deserialize)]
2743    struct Hero {
2744        id: Option<i64>,
2745        team: Lazy<Team>,
2746    }
2747
2748    impl Model for Hero {
2749        const TABLE_NAME: &'static str = "heroes";
2750        const PRIMARY_KEY: &'static [&'static str] = &["id"];
2751
2752        fn fields() -> &'static [sqlmodel_core::FieldInfo] {
2753            &[]
2754        }
2755
2756        fn to_row(&self) -> Vec<(&'static str, Value)> {
2757            vec![]
2758        }
2759
2760        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
2761            Ok(Self {
2762                id: None,
2763                team: Lazy::empty(),
2764            })
2765        }
2766
2767        fn primary_key_value(&self) -> Vec<Value> {
2768            self.id
2769                .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
2770        }
2771
2772        fn is_new(&self) -> bool {
2773            self.id.is_none()
2774        }
2775    }
2776
2777    #[derive(Debug, Default)]
2778    struct MockState {
2779        query_calls: usize,
2780    }
2781
2782    #[derive(Debug, Clone)]
2783    struct MockConnection {
2784        state: Arc<Mutex<MockState>>,
2785    }
2786
2787    impl sqlmodel_core::Connection for MockConnection {
2788        type Tx<'conn>
2789            = MockTransaction
2790        where
2791            Self: 'conn;
2792
2793        fn query(
2794            &self,
2795            _cx: &Cx,
2796            _sql: &str,
2797            params: &[Value],
2798        ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2799            let params = params.to_vec();
2800            let state = Arc::clone(&self.state);
2801            async move {
2802                state.lock().expect("lock poisoned").query_calls += 1;
2803
2804                let mut rows = Vec::new();
2805                for v in params {
2806                    match v {
2807                        Value::BigInt(1) => rows.push(Row::new(
2808                            vec!["id".into(), "name".into()],
2809                            vec![Value::BigInt(1), Value::Text("Avengers".into())],
2810                        )),
2811                        Value::BigInt(2) => rows.push(Row::new(
2812                            vec!["id".into(), "name".into()],
2813                            vec![Value::BigInt(2), Value::Text("X-Men".into())],
2814                        )),
2815                        _ => {}
2816                    }
2817                }
2818
2819                Outcome::Ok(rows)
2820            }
2821        }
2822
2823        fn query_one(
2824            &self,
2825            _cx: &Cx,
2826            _sql: &str,
2827            _params: &[Value],
2828        ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2829            async { Outcome::Ok(None) }
2830        }
2831
2832        fn execute(
2833            &self,
2834            _cx: &Cx,
2835            _sql: &str,
2836            _params: &[Value],
2837        ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2838            async { Outcome::Ok(0) }
2839        }
2840
2841        fn insert(
2842            &self,
2843            _cx: &Cx,
2844            _sql: &str,
2845            _params: &[Value],
2846        ) -> impl Future<Output = Outcome<i64, Error>> + Send {
2847            async { Outcome::Ok(0) }
2848        }
2849
2850        fn batch(
2851            &self,
2852            _cx: &Cx,
2853            _statements: &[(String, Vec<Value>)],
2854        ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
2855            async { Outcome::Ok(vec![]) }
2856        }
2857
2858        fn begin(&self, _cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2859            async { Outcome::Ok(MockTransaction) }
2860        }
2861
2862        fn begin_with(
2863            &self,
2864            _cx: &Cx,
2865            _isolation: sqlmodel_core::connection::IsolationLevel,
2866        ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2867            async { Outcome::Ok(MockTransaction) }
2868        }
2869
2870        fn prepare(
2871            &self,
2872            _cx: &Cx,
2873            _sql: &str,
2874        ) -> impl Future<Output = Outcome<sqlmodel_core::connection::PreparedStatement, Error>> + Send
2875        {
2876            async {
2877                Outcome::Ok(sqlmodel_core::connection::PreparedStatement::new(
2878                    0,
2879                    String::new(),
2880                    0,
2881                ))
2882            }
2883        }
2884
2885        fn query_prepared(
2886            &self,
2887            _cx: &Cx,
2888            _stmt: &sqlmodel_core::connection::PreparedStatement,
2889            _params: &[Value],
2890        ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2891            async { Outcome::Ok(vec![]) }
2892        }
2893
2894        fn execute_prepared(
2895            &self,
2896            _cx: &Cx,
2897            _stmt: &sqlmodel_core::connection::PreparedStatement,
2898            _params: &[Value],
2899        ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2900            async { Outcome::Ok(0) }
2901        }
2902
2903        fn ping(&self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2904            async { Outcome::Ok(()) }
2905        }
2906
2907        fn close(self, _cx: &Cx) -> impl Future<Output = sqlmodel_core::Result<()>> + Send {
2908            async { Ok(()) }
2909        }
2910    }
2911
2912    struct MockTransaction;
2913
2914    impl sqlmodel_core::connection::TransactionOps for MockTransaction {
2915        fn query(
2916            &self,
2917            _cx: &Cx,
2918            _sql: &str,
2919            _params: &[Value],
2920        ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2921            async { Outcome::Ok(vec![]) }
2922        }
2923
2924        fn query_one(
2925            &self,
2926            _cx: &Cx,
2927            _sql: &str,
2928            _params: &[Value],
2929        ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2930            async { Outcome::Ok(None) }
2931        }
2932
2933        fn execute(
2934            &self,
2935            _cx: &Cx,
2936            _sql: &str,
2937            _params: &[Value],
2938        ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2939            async { Outcome::Ok(0) }
2940        }
2941
2942        fn savepoint(
2943            &self,
2944            _cx: &Cx,
2945            _name: &str,
2946        ) -> impl Future<Output = Outcome<(), Error>> + Send {
2947            async { Outcome::Ok(()) }
2948        }
2949
2950        fn rollback_to(
2951            &self,
2952            _cx: &Cx,
2953            _name: &str,
2954        ) -> impl Future<Output = Outcome<(), Error>> + Send {
2955            async { Outcome::Ok(()) }
2956        }
2957
2958        fn release(
2959            &self,
2960            _cx: &Cx,
2961            _name: &str,
2962        ) -> impl Future<Output = Outcome<(), Error>> + Send {
2963            async { Outcome::Ok(()) }
2964        }
2965
2966        fn commit(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2967            async { Outcome::Ok(()) }
2968        }
2969
2970        fn rollback(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2971            async { Outcome::Ok(()) }
2972        }
2973    }
2974
2975    #[test]
2976    fn test_load_many_single_query_and_populates_lazy() {
2977        let rt = RuntimeBuilder::current_thread()
2978            .build()
2979            .expect("create asupersync runtime");
2980        let cx = Cx::for_testing();
2981
2982        let state = Arc::new(Mutex::new(MockState::default()));
2983        let conn = MockConnection {
2984            state: Arc::clone(&state),
2985        };
2986        let mut session = Session::new(conn);
2987
2988        let heroes = vec![
2989            Hero {
2990                id: Some(1),
2991                team: Lazy::from_fk(1_i64),
2992            },
2993            Hero {
2994                id: Some(2),
2995                team: Lazy::from_fk(2_i64),
2996            },
2997            Hero {
2998                id: Some(3),
2999                team: Lazy::from_fk(1_i64),
3000            },
3001            Hero {
3002                id: Some(4),
3003                team: Lazy::empty(),
3004            },
3005            Hero {
3006                id: Some(5),
3007                team: Lazy::from_fk(999_i64),
3008            },
3009        ];
3010
3011        rt.block_on(async {
3012            let loaded = unwrap_outcome(
3013                session
3014                    .load_many::<Hero, Team, _>(&cx, &heroes, |h| &h.team)
3015                    .await,
3016            );
3017            assert_eq!(loaded, 3);
3018
3019            // Populated / cached
3020            assert!(heroes[0].team.is_loaded());
3021            assert_eq!(heroes[0].team.get().unwrap().name, "Avengers");
3022            assert_eq!(heroes[1].team.get().unwrap().name, "X-Men");
3023            assert_eq!(heroes[2].team.get().unwrap().name, "Avengers");
3024
3025            // Empty FK gets cached as loaded-none
3026            assert!(heroes[3].team.is_loaded());
3027            assert!(heroes[3].team.get().is_none());
3028
3029            // Missing object gets cached as loaded-none
3030            assert!(heroes[4].team.is_loaded());
3031            assert!(heroes[4].team.get().is_none());
3032
3033            // Identity map populated: get() should not hit the connection again
3034            let team1 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await);
3035            assert_eq!(
3036                team1,
3037                Some(Team {
3038                    id: Some(1),
3039                    name: "Avengers".to_string()
3040                })
3041            );
3042        });
3043
3044        assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3045    }
3046
3047    #[test]
3048    fn test_add_all_with_vec() {
3049        let state = Arc::new(Mutex::new(MockState::default()));
3050        let conn = MockConnection {
3051            state: Arc::clone(&state),
3052        };
3053        let mut session = Session::new(conn);
3054
3055        // Each object needs a unique PK for identity tracking
3056        // (objects without PKs get the same ObjectKey)
3057        let teams = vec![
3058            Team {
3059                id: Some(100),
3060                name: "Team A".to_string(),
3061            },
3062            Team {
3063                id: Some(101),
3064                name: "Team B".to_string(),
3065            },
3066            Team {
3067                id: Some(102),
3068                name: "Team C".to_string(),
3069            },
3070        ];
3071
3072        session.add_all(&teams);
3073
3074        let info = session.debug_state();
3075        assert_eq!(info.pending_new, 3);
3076        assert_eq!(info.tracked, 3);
3077    }
3078
3079    #[test]
3080    fn test_add_all_with_empty_collection() {
3081        let state = Arc::new(Mutex::new(MockState::default()));
3082        let conn = MockConnection {
3083            state: Arc::clone(&state),
3084        };
3085        let mut session = Session::new(conn);
3086
3087        let teams: Vec<Team> = vec![];
3088        session.add_all(&teams);
3089
3090        let info = session.debug_state();
3091        assert_eq!(info.pending_new, 0);
3092        assert_eq!(info.tracked, 0);
3093    }
3094
3095    #[test]
3096    fn test_add_all_with_iterator() {
3097        let state = Arc::new(Mutex::new(MockState::default()));
3098        let conn = MockConnection {
3099            state: Arc::clone(&state),
3100        };
3101        let mut session = Session::new(conn);
3102
3103        let teams = [
3104            Team {
3105                id: Some(200),
3106                name: "Team X".to_string(),
3107            },
3108            Team {
3109                id: Some(201),
3110                name: "Team Y".to_string(),
3111            },
3112        ];
3113
3114        // Use iter() explicitly
3115        session.add_all(teams.iter());
3116
3117        let info = session.debug_state();
3118        assert_eq!(info.pending_new, 2);
3119        assert_eq!(info.tracked, 2);
3120    }
3121
3122    #[test]
3123    fn test_add_all_with_slice() {
3124        let state = Arc::new(Mutex::new(MockState::default()));
3125        let conn = MockConnection {
3126            state: Arc::clone(&state),
3127        };
3128        let mut session = Session::new(conn);
3129
3130        let teams = [
3131            Team {
3132                id: Some(300),
3133                name: "Team 1".to_string(),
3134            },
3135            Team {
3136                id: Some(301),
3137                name: "Team 2".to_string(),
3138            },
3139        ];
3140
3141        session.add_all(&teams);
3142
3143        let info = session.debug_state();
3144        assert_eq!(info.pending_new, 2);
3145        assert_eq!(info.tracked, 2);
3146    }
3147
3148    // ==================== Merge Tests ====================
3149
3150    #[test]
3151    fn test_merge_new_object_without_load() {
3152        let rt = RuntimeBuilder::current_thread()
3153            .build()
3154            .expect("create asupersync runtime");
3155        let cx = Cx::for_testing();
3156
3157        let state = Arc::new(Mutex::new(MockState::default()));
3158        let conn = MockConnection {
3159            state: Arc::clone(&state),
3160        };
3161        let mut session = Session::new(conn);
3162
3163        rt.block_on(async {
3164            // Merge a new object without loading from DB
3165            let team = Team {
3166                id: Some(100),
3167                name: "New Team".to_string(),
3168            };
3169
3170            let merged = unwrap_outcome(session.merge(&cx, team.clone(), false).await);
3171
3172            // Should be the same object
3173            assert_eq!(merged.id, Some(100));
3174            assert_eq!(merged.name, "New Team");
3175
3176            // Should be tracked as new
3177            let info = session.debug_state();
3178            assert_eq!(info.pending_new, 1);
3179            assert_eq!(info.tracked, 1);
3180        });
3181
3182        // Should not have queried DB (load=false)
3183        assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3184    }
3185
3186    #[test]
3187    fn test_merge_updates_existing_tracked_object() {
3188        let rt = RuntimeBuilder::current_thread()
3189            .build()
3190            .expect("create asupersync runtime");
3191        let cx = Cx::for_testing();
3192
3193        let state = Arc::new(Mutex::new(MockState::default()));
3194        let conn = MockConnection {
3195            state: Arc::clone(&state),
3196        };
3197        let mut session = Session::new(conn);
3198
3199        rt.block_on(async {
3200            // First add an object
3201            let original = Team {
3202                id: Some(1),
3203                name: "Original".to_string(),
3204            };
3205            session.add(&original);
3206
3207            // Now merge an updated version
3208            let updated = Team {
3209                id: Some(1),
3210                name: "Updated".to_string(),
3211            };
3212
3213            let merged = unwrap_outcome(session.merge(&cx, updated, false).await);
3214
3215            // Should have the updated name
3216            assert_eq!(merged.id, Some(1));
3217            assert_eq!(merged.name, "Updated");
3218
3219            // Should still be tracked (not duplicated)
3220            let info = session.debug_state();
3221            assert_eq!(info.tracked, 1);
3222        });
3223    }
3224
3225    #[test]
3226    fn test_merge_with_load_queries_database() {
3227        let rt = RuntimeBuilder::current_thread()
3228            .build()
3229            .expect("create asupersync runtime");
3230        let cx = Cx::for_testing();
3231
3232        let state = Arc::new(Mutex::new(MockState::default()));
3233        let conn = MockConnection {
3234            state: Arc::clone(&state),
3235        };
3236        let mut session = Session::new(conn);
3237
3238        rt.block_on(async {
3239            // Merge an object that exists in the "database" (mock returns it for id=1)
3240            let detached = Team {
3241                id: Some(1),
3242                name: "Detached Update".to_string(),
3243            };
3244
3245            let merged = unwrap_outcome(session.merge(&cx, detached, true).await);
3246
3247            // Should have the name from our detached object (merged onto DB values)
3248            assert_eq!(merged.id, Some(1));
3249            assert_eq!(merged.name, "Detached Update");
3250
3251            // Should be tracked and marked as dirty
3252            let info = session.debug_state();
3253            assert_eq!(info.tracked, 1);
3254            assert_eq!(info.pending_dirty, 1);
3255        });
3256
3257        // Should have queried DB once (load=true)
3258        assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3259    }
3260
3261    #[test]
3262    fn test_merge_with_load_not_found_creates_new() {
3263        let rt = RuntimeBuilder::current_thread()
3264            .build()
3265            .expect("create asupersync runtime");
3266        let cx = Cx::for_testing();
3267
3268        let state = Arc::new(Mutex::new(MockState::default()));
3269        let conn = MockConnection {
3270            state: Arc::clone(&state),
3271        };
3272        let mut session = Session::new(conn);
3273
3274        rt.block_on(async {
3275            // Merge an object that doesn't exist in DB (mock returns None for id=999)
3276            let detached = Team {
3277                id: Some(999),
3278                name: "Not In DB".to_string(),
3279            };
3280
3281            let merged = unwrap_outcome(session.merge(&cx, detached, true).await);
3282
3283            // Should keep the values we provided
3284            assert_eq!(merged.id, Some(999));
3285            assert_eq!(merged.name, "Not In DB");
3286
3287            // Should be tracked as new
3288            let info = session.debug_state();
3289            assert_eq!(info.pending_new, 1);
3290            assert_eq!(info.tracked, 1);
3291        });
3292
3293        // Should have queried DB once
3294        assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3295    }
3296
3297    #[test]
3298    fn test_merge_without_load_convenience() {
3299        let rt = RuntimeBuilder::current_thread()
3300            .build()
3301            .expect("create asupersync runtime");
3302        let cx = Cx::for_testing();
3303
3304        let state = Arc::new(Mutex::new(MockState::default()));
3305        let conn = MockConnection {
3306            state: Arc::clone(&state),
3307        };
3308        let mut session = Session::new(conn);
3309
3310        rt.block_on(async {
3311            let team = Team {
3312                id: Some(42),
3313                name: "Test".to_string(),
3314            };
3315
3316            // Use the convenience method
3317            let merged = unwrap_outcome(session.merge_without_load(&cx, team).await);
3318
3319            assert_eq!(merged.id, Some(42));
3320            assert_eq!(merged.name, "Test");
3321
3322            let info = session.debug_state();
3323            assert_eq!(info.pending_new, 1);
3324        });
3325
3326        // Should not have queried DB
3327        assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3328    }
3329
3330    #[test]
3331    fn test_merge_null_pk_treated_as_new() {
3332        let rt = RuntimeBuilder::current_thread()
3333            .build()
3334            .expect("create asupersync runtime");
3335        let cx = Cx::for_testing();
3336
3337        let state = Arc::new(Mutex::new(MockState::default()));
3338        let conn = MockConnection {
3339            state: Arc::clone(&state),
3340        };
3341        let mut session = Session::new(conn);
3342
3343        rt.block_on(async {
3344            // Merge object with null PK (new record)
3345            let new_team = Team {
3346                id: None,
3347                name: "Brand New".to_string(),
3348            };
3349
3350            let merged = unwrap_outcome(session.merge(&cx, new_team, true).await);
3351
3352            // Should keep the null id
3353            assert_eq!(merged.id, None);
3354            assert_eq!(merged.name, "Brand New");
3355
3356            // Should be tracked as new (no DB query for null PK)
3357            let info = session.debug_state();
3358            assert_eq!(info.pending_new, 1);
3359        });
3360
3361        // Should not have queried DB for null PK
3362        assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3363    }
3364
3365    // ==================== is_modified Tests ====================
3366
3367    #[test]
3368    fn test_is_modified_new_object_returns_true() {
3369        let state = Arc::new(Mutex::new(MockState::default()));
3370        let conn = MockConnection {
3371            state: Arc::clone(&state),
3372        };
3373        let mut session = Session::new(conn);
3374
3375        let team = Team {
3376            id: Some(100),
3377            name: "New Team".to_string(),
3378        };
3379
3380        // Add as new - should be modified
3381        session.add(&team);
3382        assert!(session.is_modified(&team));
3383    }
3384
3385    #[test]
3386    fn test_is_modified_untracked_returns_false() {
3387        let state = Arc::new(Mutex::new(MockState::default()));
3388        let conn = MockConnection {
3389            state: Arc::clone(&state),
3390        };
3391        let session = Session::<MockConnection>::new(conn);
3392
3393        let team = Team {
3394            id: Some(100),
3395            name: "Not Tracked".to_string(),
3396        };
3397
3398        // Not tracked - should not be modified
3399        assert!(!session.is_modified(&team));
3400    }
3401
3402    #[test]
3403    fn test_is_modified_after_load_returns_false() {
3404        let rt = RuntimeBuilder::current_thread()
3405            .build()
3406            .expect("create asupersync runtime");
3407        let cx = Cx::for_testing();
3408
3409        let state = Arc::new(Mutex::new(MockState::default()));
3410        let conn = MockConnection {
3411            state: Arc::clone(&state),
3412        };
3413        let mut session = Session::new(conn);
3414
3415        rt.block_on(async {
3416            // Load from DB
3417            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3418
3419            // Fresh from DB - should not be modified
3420            assert!(!session.is_modified(&team));
3421        });
3422    }
3423
3424    #[test]
3425    fn test_is_modified_after_mark_dirty_returns_true() {
3426        let rt = RuntimeBuilder::current_thread()
3427            .build()
3428            .expect("create asupersync runtime");
3429        let cx = Cx::for_testing();
3430
3431        let state = Arc::new(Mutex::new(MockState::default()));
3432        let conn = MockConnection {
3433            state: Arc::clone(&state),
3434        };
3435        let mut session = Session::new(conn);
3436
3437        rt.block_on(async {
3438            // Load from DB
3439            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3440            assert!(!session.is_modified(&team));
3441
3442            // Modify and mark dirty
3443            let mut modified_team = team.clone();
3444            modified_team.name = "Modified Name".to_string();
3445            session.mark_dirty(&modified_team);
3446
3447            // Should now be modified
3448            assert!(session.is_modified(&modified_team));
3449        });
3450    }
3451
3452    #[test]
3453    fn test_is_modified_deleted_returns_true() {
3454        let rt = RuntimeBuilder::current_thread()
3455            .build()
3456            .expect("create asupersync runtime");
3457        let cx = Cx::for_testing();
3458
3459        let state = Arc::new(Mutex::new(MockState::default()));
3460        let conn = MockConnection {
3461            state: Arc::clone(&state),
3462        };
3463        let mut session = Session::new(conn);
3464
3465        rt.block_on(async {
3466            // Load from DB
3467            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3468            assert!(!session.is_modified(&team));
3469
3470            // Delete
3471            session.delete(&team);
3472
3473            // Should be modified (pending delete)
3474            assert!(session.is_modified(&team));
3475        });
3476    }
3477
3478    #[test]
3479    fn test_is_modified_detached_returns_false() {
3480        let rt = RuntimeBuilder::current_thread()
3481            .build()
3482            .expect("create asupersync runtime");
3483        let cx = Cx::for_testing();
3484
3485        let state = Arc::new(Mutex::new(MockState::default()));
3486        let conn = MockConnection {
3487            state: Arc::clone(&state),
3488        };
3489        let mut session = Session::new(conn);
3490
3491        rt.block_on(async {
3492            // Load from DB
3493            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3494
3495            // Detach
3496            session.expunge(&team);
3497
3498            // Detached objects aren't modified in session context
3499            assert!(!session.is_modified(&team));
3500        });
3501    }
3502
3503    #[test]
3504    fn test_object_state_returns_correct_state() {
3505        let rt = RuntimeBuilder::current_thread()
3506            .build()
3507            .expect("create asupersync runtime");
3508        let cx = Cx::for_testing();
3509
3510        let state = Arc::new(Mutex::new(MockState::default()));
3511        let conn = MockConnection {
3512            state: Arc::clone(&state),
3513        };
3514        let mut session = Session::new(conn);
3515
3516        // Untracked object
3517        let untracked = Team {
3518            id: Some(999),
3519            name: "Untracked".to_string(),
3520        };
3521        assert_eq!(session.object_state(&untracked), None);
3522
3523        // New object
3524        let new_team = Team {
3525            id: Some(100),
3526            name: "New".to_string(),
3527        };
3528        session.add(&new_team);
3529        assert_eq!(session.object_state(&new_team), Some(ObjectState::New));
3530
3531        rt.block_on(async {
3532            // Persistent object
3533            let persistent = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3534            assert_eq!(
3535                session.object_state(&persistent),
3536                Some(ObjectState::Persistent)
3537            );
3538
3539            // Deleted object
3540            session.delete(&persistent);
3541            assert_eq!(
3542                session.object_state(&persistent),
3543                Some(ObjectState::Deleted)
3544            );
3545        });
3546    }
3547
3548    #[test]
3549    fn test_modified_attributes_returns_changed_columns() {
3550        let rt = RuntimeBuilder::current_thread()
3551            .build()
3552            .expect("create asupersync runtime");
3553        let cx = Cx::for_testing();
3554
3555        let state = Arc::new(Mutex::new(MockState::default()));
3556        let conn = MockConnection {
3557            state: Arc::clone(&state),
3558        };
3559        let mut session = Session::new(conn);
3560
3561        rt.block_on(async {
3562            // Load from DB
3563            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3564
3565            // No modifications yet
3566            let modified = session.modified_attributes(&team);
3567            assert!(modified.is_empty());
3568
3569            // Modify and mark dirty
3570            let mut modified_team = team.clone();
3571            modified_team.name = "Changed Name".to_string();
3572            session.mark_dirty(&modified_team);
3573
3574            // Should show 'name' as modified
3575            let modified = session.modified_attributes(&modified_team);
3576            assert!(modified.contains(&"name"));
3577        });
3578    }
3579
3580    #[test]
3581    fn test_modified_attributes_untracked_returns_empty() {
3582        let state = Arc::new(Mutex::new(MockState::default()));
3583        let conn = MockConnection {
3584            state: Arc::clone(&state),
3585        };
3586        let session = Session::<MockConnection>::new(conn);
3587
3588        let team = Team {
3589            id: Some(100),
3590            name: "Not Tracked".to_string(),
3591        };
3592
3593        let modified = session.modified_attributes(&team);
3594        assert!(modified.is_empty());
3595    }
3596
3597    #[test]
3598    fn test_modified_attributes_new_returns_empty() {
3599        let state = Arc::new(Mutex::new(MockState::default()));
3600        let conn = MockConnection {
3601            state: Arc::clone(&state),
3602        };
3603        let mut session = Session::new(conn);
3604
3605        let team = Team {
3606            id: Some(100),
3607            name: "New".to_string(),
3608        };
3609        session.add(&team);
3610
3611        // New objects don't have original values to compare
3612        let modified = session.modified_attributes(&team);
3613        assert!(modified.is_empty());
3614    }
3615
3616    // ==================== Expire Tests ====================
3617
3618    #[test]
3619    fn test_expire_marks_object_as_expired() {
3620        let rt = RuntimeBuilder::current_thread()
3621            .build()
3622            .expect("create asupersync runtime");
3623        let cx = Cx::for_testing();
3624
3625        let state = Arc::new(Mutex::new(MockState::default()));
3626        let conn = MockConnection {
3627            state: Arc::clone(&state),
3628        };
3629        let mut session = Session::new(conn);
3630
3631        rt.block_on(async {
3632            // Get an object from DB (creates Persistent state)
3633            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await);
3634            assert!(team.is_some());
3635            let team = team.unwrap();
3636
3637            // Verify it's not expired initially
3638            assert!(!session.is_expired(&team));
3639            assert_eq!(session.object_state(&team), Some(ObjectState::Persistent));
3640
3641            // Expire all attributes
3642            session.expire(&team, None);
3643
3644            // Should now be expired
3645            assert!(session.is_expired(&team));
3646            assert_eq!(session.object_state(&team), Some(ObjectState::Expired));
3647        });
3648    }
3649
3650    #[test]
3651    fn test_expire_specific_attributes() {
3652        let rt = RuntimeBuilder::current_thread()
3653            .build()
3654            .expect("create asupersync runtime");
3655        let cx = Cx::for_testing();
3656
3657        let state = Arc::new(Mutex::new(MockState::default()));
3658        let conn = MockConnection {
3659            state: Arc::clone(&state),
3660        };
3661        let mut session = Session::new(conn);
3662
3663        rt.block_on(async {
3664            // Get an object from DB
3665            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3666
3667            // Expire specific attributes
3668            session.expire(&team, Some(&["name"]));
3669
3670            // Should be expired
3671            assert!(session.is_expired(&team));
3672
3673            // Check expired attributes
3674            let expired = session.expired_attributes(&team);
3675            assert!(expired.is_some());
3676            let expired_set = expired.unwrap();
3677            assert!(expired_set.is_some());
3678            assert!(expired_set.unwrap().contains("name"));
3679        });
3680    }
3681
3682    #[test]
3683    fn test_expire_all_marks_all_objects_expired() {
3684        let rt = RuntimeBuilder::current_thread()
3685            .build()
3686            .expect("create asupersync runtime");
3687        let cx = Cx::for_testing();
3688
3689        let state = Arc::new(Mutex::new(MockState::default()));
3690        let conn = MockConnection {
3691            state: Arc::clone(&state),
3692        };
3693        let mut session = Session::new(conn);
3694
3695        rt.block_on(async {
3696            // Get multiple objects from DB
3697            let team1 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3698            let team2 = unwrap_outcome(session.get::<Team>(&cx, 2_i64).await).unwrap();
3699
3700            // Verify neither is expired
3701            assert!(!session.is_expired(&team1));
3702            assert!(!session.is_expired(&team2));
3703
3704            // Expire all
3705            session.expire_all();
3706
3707            // Both should be expired
3708            assert!(session.is_expired(&team1));
3709            assert!(session.is_expired(&team2));
3710        });
3711    }
3712
3713    #[test]
3714    fn test_expire_does_not_affect_new_objects() {
3715        let state = Arc::new(Mutex::new(MockState::default()));
3716        let conn = MockConnection {
3717            state: Arc::clone(&state),
3718        };
3719        let mut session = Session::new(conn);
3720
3721        // Add a new object
3722        let team = Team {
3723            id: Some(100),
3724            name: "New Team".to_string(),
3725        };
3726        session.add(&team);
3727
3728        // Try to expire it
3729        session.expire(&team, None);
3730
3731        // Should still be New, not Expired
3732        assert_eq!(session.object_state(&team), Some(ObjectState::New));
3733        assert!(!session.is_expired(&team));
3734    }
3735
3736    #[test]
3737    fn test_expired_object_reloads_on_get() {
3738        let rt = RuntimeBuilder::current_thread()
3739            .build()
3740            .expect("create asupersync runtime");
3741        let cx = Cx::for_testing();
3742
3743        let state = Arc::new(Mutex::new(MockState::default()));
3744        let conn = MockConnection {
3745            state: Arc::clone(&state),
3746        };
3747        let mut session = Session::new(conn);
3748
3749        rt.block_on(async {
3750            // Get an object (query 1)
3751            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3752            assert_eq!(team.name, "Avengers");
3753
3754            // Get again - should use cache (no additional query)
3755            let team2 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3756            assert_eq!(team2.name, "Avengers");
3757
3758            // Verify only 1 query so far
3759            {
3760                let s = state.lock().expect("lock poisoned");
3761                assert_eq!(s.query_calls, 1);
3762            }
3763
3764            // Expire the object
3765            session.expire(&team, None);
3766
3767            // Get again - should reload from DB (query 2)
3768            let team3 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3769            assert_eq!(team3.name, "Avengers");
3770
3771            // Verify a second query was made
3772            {
3773                let s = state.lock().expect("lock poisoned");
3774                assert_eq!(s.query_calls, 2);
3775            }
3776
3777            // Should no longer be expired after reload
3778            assert!(!session.is_expired(&team3));
3779            assert_eq!(session.object_state(&team3), Some(ObjectState::Persistent));
3780        });
3781    }
3782
3783    #[test]
3784    fn test_is_expired_returns_false_for_untracked() {
3785        let state = Arc::new(Mutex::new(MockState::default()));
3786        let conn = MockConnection {
3787            state: Arc::clone(&state),
3788        };
3789        let session = Session::<MockConnection>::new(conn);
3790
3791        let team = Team {
3792            id: Some(999),
3793            name: "Not Tracked".to_string(),
3794        };
3795
3796        // Should return false for untracked objects
3797        assert!(!session.is_expired(&team));
3798    }
3799
3800    #[test]
3801    fn test_expired_attributes_returns_none_for_persistent() {
3802        let rt = RuntimeBuilder::current_thread()
3803            .build()
3804            .expect("create asupersync runtime");
3805        let cx = Cx::for_testing();
3806
3807        let state = Arc::new(Mutex::new(MockState::default()));
3808        let conn = MockConnection {
3809            state: Arc::clone(&state),
3810        };
3811        let mut session = Session::new(conn);
3812
3813        rt.block_on(async {
3814            // Get an object (Persistent state)
3815            let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3816
3817            // Should return None for non-expired objects
3818            let expired = session.expired_attributes(&team);
3819            assert!(expired.is_none());
3820        });
3821    }
3822}