1pub mod change_tracker;
43pub mod flush;
44pub mod identity_map;
45pub mod n1_detection;
46pub mod unit_of_work;
47
48pub use change_tracker::{ChangeTracker, ObjectSnapshot};
49pub use flush::{
50 FlushOrderer, FlushPlan, FlushResult, LinkTableOp, PendingOp, execute_link_table_ops,
51};
52pub use identity_map::{IdentityMap, ModelReadGuard, ModelRef, ModelWriteGuard, WeakIdentityMap};
53pub use n1_detection::{CallSite, N1DetectionScope, N1QueryTracker, N1Stats};
54pub use unit_of_work::{PendingCounts, UnitOfWork, UowError};
55
56use asupersync::{Cx, Outcome};
57use serde::{Deserialize, Serialize};
58use sqlmodel_core::{Connection, Error, Lazy, LazyLoader, Model, Value};
59use std::any::{Any, TypeId};
60use std::collections::HashMap;
61use std::future::Future;
62use std::hash::{Hash, Hasher};
63
64type SessionEventFn = Box<dyn FnMut() -> Result<(), Error> + Send>;
73
74#[derive(Default)]
79pub struct SessionEventCallbacks {
80 before_flush: Vec<SessionEventFn>,
81 after_flush: Vec<SessionEventFn>,
82 before_commit: Vec<SessionEventFn>,
83 after_commit: Vec<SessionEventFn>,
84 after_rollback: Vec<SessionEventFn>,
85}
86
87impl std::fmt::Debug for SessionEventCallbacks {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("SessionEventCallbacks")
90 .field("before_flush", &self.before_flush.len())
91 .field("after_flush", &self.after_flush.len())
92 .field("before_commit", &self.before_commit.len())
93 .field("after_commit", &self.after_commit.len())
94 .field("after_rollback", &self.after_rollback.len())
95 .finish()
96 }
97}
98
99impl SessionEventCallbacks {
100 #[allow(clippy::result_large_err)]
101 fn fire(&mut self, event: SessionEvent) -> Result<(), Error> {
102 let callbacks = match event {
103 SessionEvent::BeforeFlush => &mut self.before_flush,
104 SessionEvent::AfterFlush => &mut self.after_flush,
105 SessionEvent::BeforeCommit => &mut self.before_commit,
106 SessionEvent::AfterCommit => &mut self.after_commit,
107 SessionEvent::AfterRollback => &mut self.after_rollback,
108 };
109 for cb in callbacks.iter_mut() {
110 cb()?;
111 }
112 Ok(())
113 }
114}
115
/// Lifecycle points at which registered session hooks run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionEvent {
    /// Fired at the start of a flush, before any SQL is issued.
    BeforeFlush,
    /// Event for hooks registered with `on_after_flush`.
    AfterFlush,
    /// Event for hooks registered with `on_before_commit`.
    BeforeCommit,
    /// Event for hooks registered with `on_after_commit`.
    AfterCommit,
    /// Event for hooks registered with `on_after_rollback`.
    AfterRollback,
}

/// Behavioral switches for a `Session`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// When set, `flush` issues `BEGIN` first if no transaction is open.
    pub auto_begin: bool,
    /// Automatic flushing switch. Not consulted by the code visible in this
    /// section (NOTE(review): confirm where it is read).
    pub auto_flush: bool,
    /// Expire-after-commit switch. Not consulted by the code visible in this
    /// section (NOTE(review): confirm where it is read).
    pub expire_on_commit: bool,
}

impl Default for SessionConfig {
    /// `auto_begin` and `expire_on_commit` on, `auto_flush` off.
    fn default() -> Self {
        let (auto_begin, auto_flush, expire_on_commit) = (true, false, true);
        Self {
            auto_begin,
            auto_flush,
            expire_on_commit,
        }
    }
}

/// Options controlling how `Session::get_with_options` loads a row.
#[derive(Debug, Clone, Default)]
pub struct GetOptions {
    /// Append `FOR UPDATE` to the query and bypass the identity-map cache.
    pub with_for_update: bool,
    /// With `with_for_update`, append `SKIP LOCKED` (wins over `nowait`).
    pub skip_locked: bool,
    /// With `with_for_update`, append `NOWAIT` unless `skip_locked` is set.
    pub nowait: bool,
}

impl GetOptions {
    /// All options off; same as [`GetOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy with `with_for_update` set to `value`.
    #[must_use]
    pub fn with_for_update(self, value: bool) -> Self {
        Self {
            with_for_update: value,
            ..self
        }
    }

    /// Returns a copy with `skip_locked` set to `value`.
    #[must_use]
    pub fn skip_locked(self, value: bool) -> Self {
        Self {
            skip_locked: value,
            ..self
        }
    }

    /// Returns a copy with `nowait` set to `value`.
    #[must_use]
    pub fn nowait(self, value: bool) -> Self {
        Self {
            nowait: value,
            ..self
        }
    }
}

196#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
202pub struct ObjectKey {
203 type_id: TypeId,
205 pk_hash: u64,
207}
208
209impl ObjectKey {
210 pub fn from_model<M: Model + 'static>(obj: &M) -> Self {
212 let pk_values = obj.primary_key_value();
213 Self {
214 type_id: TypeId::of::<M>(),
215 pk_hash: hash_values(&pk_values),
216 }
217 }
218
219 pub fn from_pk<M: Model + 'static>(pk: &[Value]) -> Self {
221 Self {
222 type_id: TypeId::of::<M>(),
223 pk_hash: hash_values(pk),
224 }
225 }
226
227 pub fn pk_hash(&self) -> u64 {
229 self.pk_hash
230 }
231
232 pub fn type_id(&self) -> TypeId {
234 self.type_id
235 }
236}
237
238fn hash_values(values: &[Value]) -> u64 {
240 use std::collections::hash_map::DefaultHasher;
241 let mut hasher = DefaultHasher::new();
242 for v in values {
243 match v {
245 Value::Null => 0u8.hash(&mut hasher),
246 Value::Bool(b) => {
247 1u8.hash(&mut hasher);
248 b.hash(&mut hasher);
249 }
250 Value::TinyInt(i) => {
251 2u8.hash(&mut hasher);
252 i.hash(&mut hasher);
253 }
254 Value::SmallInt(i) => {
255 3u8.hash(&mut hasher);
256 i.hash(&mut hasher);
257 }
258 Value::Int(i) => {
259 4u8.hash(&mut hasher);
260 i.hash(&mut hasher);
261 }
262 Value::BigInt(i) => {
263 5u8.hash(&mut hasher);
264 i.hash(&mut hasher);
265 }
266 Value::Float(f) => {
267 6u8.hash(&mut hasher);
268 f.to_bits().hash(&mut hasher);
269 }
270 Value::Double(f) => {
271 7u8.hash(&mut hasher);
272 f.to_bits().hash(&mut hasher);
273 }
274 Value::Decimal(s) => {
275 8u8.hash(&mut hasher);
276 s.hash(&mut hasher);
277 }
278 Value::Text(s) => {
279 9u8.hash(&mut hasher);
280 s.hash(&mut hasher);
281 }
282 Value::Bytes(b) => {
283 10u8.hash(&mut hasher);
284 b.hash(&mut hasher);
285 }
286 Value::Date(d) => {
287 11u8.hash(&mut hasher);
288 d.hash(&mut hasher);
289 }
290 Value::Time(t) => {
291 12u8.hash(&mut hasher);
292 t.hash(&mut hasher);
293 }
294 Value::Timestamp(ts) => {
295 13u8.hash(&mut hasher);
296 ts.hash(&mut hasher);
297 }
298 Value::TimestampTz(ts) => {
299 14u8.hash(&mut hasher);
300 ts.hash(&mut hasher);
301 }
302 Value::Uuid(u) => {
303 15u8.hash(&mut hasher);
304 u.hash(&mut hasher);
305 }
306 Value::Json(j) => {
307 16u8.hash(&mut hasher);
308 j.to_string().hash(&mut hasher);
310 }
311 Value::Array(arr) => {
312 17u8.hash(&mut hasher);
313 arr.len().hash(&mut hasher);
315 for item in arr {
316 hash_value(item, &mut hasher);
317 }
318 }
319 Value::Default => {
320 18u8.hash(&mut hasher);
321 }
322 }
323 }
324 hasher.finish()
325}
326
327fn hash_value(v: &Value, hasher: &mut impl Hasher) {
329 match v {
330 Value::Null => 0u8.hash(hasher),
331 Value::Bool(b) => {
332 1u8.hash(hasher);
333 b.hash(hasher);
334 }
335 Value::TinyInt(i) => {
336 2u8.hash(hasher);
337 i.hash(hasher);
338 }
339 Value::SmallInt(i) => {
340 3u8.hash(hasher);
341 i.hash(hasher);
342 }
343 Value::Int(i) => {
344 4u8.hash(hasher);
345 i.hash(hasher);
346 }
347 Value::BigInt(i) => {
348 5u8.hash(hasher);
349 i.hash(hasher);
350 }
351 Value::Float(f) => {
352 6u8.hash(hasher);
353 f.to_bits().hash(hasher);
354 }
355 Value::Double(f) => {
356 7u8.hash(hasher);
357 f.to_bits().hash(hasher);
358 }
359 Value::Decimal(s) => {
360 8u8.hash(hasher);
361 s.hash(hasher);
362 }
363 Value::Text(s) => {
364 9u8.hash(hasher);
365 s.hash(hasher);
366 }
367 Value::Bytes(b) => {
368 10u8.hash(hasher);
369 b.hash(hasher);
370 }
371 Value::Date(d) => {
372 11u8.hash(hasher);
373 d.hash(hasher);
374 }
375 Value::Time(t) => {
376 12u8.hash(hasher);
377 t.hash(hasher);
378 }
379 Value::Timestamp(ts) => {
380 13u8.hash(hasher);
381 ts.hash(hasher);
382 }
383 Value::TimestampTz(ts) => {
384 14u8.hash(hasher);
385 ts.hash(hasher);
386 }
387 Value::Uuid(u) => {
388 15u8.hash(hasher);
389 u.hash(hasher);
390 }
391 Value::Json(j) => {
392 16u8.hash(hasher);
393 j.to_string().hash(hasher);
394 }
395 Value::Array(arr) => {
396 17u8.hash(hasher);
397 arr.len().hash(hasher);
398 for item in arr {
399 hash_value(item, hasher);
400 }
401 }
402 Value::Default => {
403 18u8.hash(hasher);
404 }
405 }
406}
407
/// Lifecycle state of an object tracked by a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectState {
    /// Added with `add` and not loaded from the database; queued in `pending_new`.
    New,
    /// Loaded from (or re-attached to) the database-backed session.
    Persistent,
    /// Marked for deletion; queued in `pending_delete` for the next flush.
    Deleted,
    /// Expunged: the entry stays in the identity map but is dropped from
    /// every pending queue.
    Detached,
    /// Some or all attributes are stale; `get` reloads the row instead of
    /// returning the cached copy.
    Expired,
}

423struct TrackedObject {
425 object: Box<dyn Any + Send + Sync>,
427 original_state: Option<Vec<u8>>,
429 state: ObjectState,
431 table_name: &'static str,
433 column_names: Vec<&'static str>,
435 values: Vec<Value>,
437 pk_columns: Vec<&'static str>,
439 pk_values: Vec<Value>,
441 relationships: &'static [sqlmodel_core::RelationshipInfo],
443 expired_attributes: Option<std::collections::HashSet<String>>,
446}
447
/// Groups composite-key cascade deletes that share a table and column list.
///
/// Each group becomes one `DELETE ... WHERE (cols) IN ((...), ...)` statement
/// during flush.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CascadeChildDeleteKey {
    /// Child or link table to delete from.
    table: &'static str,
    /// Columns matched against the deleted parent's primary-key values.
    fk_cols: Vec<&'static str>,
}

454pub struct Session<C: Connection> {
463 connection: C,
465 in_transaction: bool,
467 identity_map: HashMap<ObjectKey, TrackedObject>,
469 pending_new: Vec<ObjectKey>,
471 pending_delete: Vec<ObjectKey>,
473 pending_dirty: Vec<ObjectKey>,
475 config: SessionConfig,
477 n1_tracker: Option<N1QueryTracker>,
479 event_callbacks: SessionEventCallbacks,
481}
482
483impl<C: Connection> Session<C> {
484 pub fn new(connection: C) -> Self {
486 Self::with_config(connection, SessionConfig::default())
487 }
488
489 pub fn with_config(connection: C, config: SessionConfig) -> Self {
491 Self {
492 connection,
493 in_transaction: false,
494 identity_map: HashMap::new(),
495 pending_new: Vec::new(),
496 pending_delete: Vec::new(),
497 pending_dirty: Vec::new(),
498 config,
499 n1_tracker: None,
500 event_callbacks: SessionEventCallbacks::default(),
501 }
502 }
503
504 pub fn connection(&self) -> &C {
506 &self.connection
507 }
508
509 pub fn config(&self) -> &SessionConfig {
511 &self.config
512 }
513
514 pub fn on_before_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
522 self.event_callbacks.before_flush.push(Box::new(f));
523 }
524
525 pub fn on_after_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
527 self.event_callbacks.after_flush.push(Box::new(f));
528 }
529
530 pub fn on_before_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
534 self.event_callbacks.before_commit.push(Box::new(f));
535 }
536
537 pub fn on_after_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
539 self.event_callbacks.after_commit.push(Box::new(f));
540 }
541
542 pub fn on_after_rollback(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
544 self.event_callbacks.after_rollback.push(Box::new(f));
545 }
546
547 pub fn add<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
555 let key = ObjectKey::from_model(obj);
556
557 if let Some(tracked) = self.identity_map.get_mut(&key) {
559 tracked.object = Box::new(obj.clone());
560
561 let row_data = obj.to_row();
563 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
564 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
565 tracked.pk_values = obj.primary_key_value();
566
567 if tracked.state == ObjectState::Deleted {
568 self.pending_delete.retain(|k| k != &key);
570
571 if tracked.original_state.is_some() {
572 tracked.state = ObjectState::Persistent;
574 } else {
575 tracked.state = ObjectState::New;
577 if !self.pending_new.contains(&key) {
578 self.pending_new.push(key);
579 }
580 }
581 }
582 return;
583 }
584
585 let row_data = obj.to_row();
587 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
588 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
589
590 let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
592 let pk_values = obj.primary_key_value();
593
594 let tracked = TrackedObject {
595 object: Box::new(obj.clone()),
596 original_state: None, state: ObjectState::New,
598 table_name: M::TABLE_NAME,
599 column_names,
600 values,
601 pk_columns,
602 pk_values,
603 relationships: M::RELATIONSHIPS,
604 expired_attributes: None,
605 };
606
607 self.identity_map.insert(key, tracked);
608 self.pending_new.push(key);
609 }
610
611 pub fn add_all<'a, M, I>(&mut self, objects: I)
628 where
629 M: Model + Clone + Send + Sync + Serialize + 'static,
630 I: IntoIterator<Item = &'a M>,
631 {
632 for obj in objects {
633 self.add(obj);
634 }
635 }
636
637 pub fn delete<M: Model + 'static>(&mut self, obj: &M) {
641 let key = ObjectKey::from_model(obj);
642
643 if let Some(tracked) = self.identity_map.get_mut(&key) {
644 match tracked.state {
645 ObjectState::New => {
646 self.identity_map.remove(&key);
648 self.pending_new.retain(|k| k != &key);
649 }
650 ObjectState::Persistent | ObjectState::Expired => {
651 tracked.state = ObjectState::Deleted;
652 self.pending_delete.push(key);
653 self.pending_dirty.retain(|k| k != &key);
654 }
655 ObjectState::Deleted | ObjectState::Detached => {
656 }
658 }
659 }
660 }
661
662 pub fn mark_dirty<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
676 let key = ObjectKey::from_model(obj);
677
678 if let Some(tracked) = self.identity_map.get_mut(&key) {
679 if tracked.state != ObjectState::Persistent {
681 return;
682 }
683
684 tracked.object = Box::new(obj.clone());
686 let row_data = obj.to_row();
687 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
688 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
689 tracked.pk_values = obj.primary_key_value();
690
691 if !self.pending_dirty.contains(&key) {
693 self.pending_dirty.push(key);
694 }
695 }
696 }
697
698 pub async fn get<
702 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
703 >(
704 &mut self,
705 cx: &Cx,
706 pk: impl Into<Value>,
707 ) -> Outcome<Option<M>, Error> {
708 let pk_value = pk.into();
709 let pk_values = vec![pk_value.clone()];
710 let key = ObjectKey::from_pk::<M>(&pk_values);
711
712 if let Some(tracked) = self.identity_map.get(&key) {
714 match tracked.state {
715 ObjectState::Deleted | ObjectState::Detached => {
716 }
718 ObjectState::Expired => {
719 tracing::debug!("Object is expired, reloading from database");
721 }
722 ObjectState::New | ObjectState::Persistent => {
723 if let Some(obj) = tracked.object.downcast_ref::<M>() {
724 return Outcome::Ok(Some(obj.clone()));
725 }
726 }
727 }
728 }
729
730 let pk_col = M::PRIMARY_KEY.first().unwrap_or(&"id");
732 let sql = format!(
733 "SELECT * FROM \"{}\" WHERE \"{}\" = $1 LIMIT 1",
734 M::TABLE_NAME,
735 pk_col
736 );
737
738 let rows = match self.connection.query(cx, &sql, &[pk_value]).await {
739 Outcome::Ok(rows) => rows,
740 Outcome::Err(e) => return Outcome::Err(e),
741 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
742 Outcome::Panicked(p) => return Outcome::Panicked(p),
743 };
744
745 if rows.is_empty() {
746 return Outcome::Ok(None);
747 }
748
749 let obj = match M::from_row(&rows[0]) {
751 Ok(obj) => obj,
752 Err(e) => return Outcome::Err(e),
753 };
754
755 let row_data = obj.to_row();
757 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
758 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
759
760 let serialized = serde_json::to_vec(&values).ok();
762
763 let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
765 let obj_pk_values = obj.primary_key_value();
766
767 let tracked = TrackedObject {
768 object: Box::new(obj.clone()),
769 original_state: serialized,
770 state: ObjectState::Persistent,
771 table_name: M::TABLE_NAME,
772 column_names,
773 values,
774 pk_columns,
775 pk_values: obj_pk_values,
776 relationships: M::RELATIONSHIPS,
777 expired_attributes: None,
778 };
779
780 self.identity_map.insert(key, tracked);
781
782 Outcome::Ok(Some(obj))
783 }
784
785 pub async fn get_by_pk<
799 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
800 >(
801 &mut self,
802 cx: &Cx,
803 pk_values: &[Value],
804 ) -> Outcome<Option<M>, Error> {
805 self.get_with_options::<M>(cx, pk_values, &GetOptions::default())
806 .await
807 }
808
809 pub async fn get_with_options<
822 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
823 >(
824 &mut self,
825 cx: &Cx,
826 pk_values: &[Value],
827 options: &GetOptions,
828 ) -> Outcome<Option<M>, Error> {
829 let key = ObjectKey::from_pk::<M>(pk_values);
830
831 if !options.with_for_update {
833 if let Some(tracked) = self.identity_map.get(&key) {
834 match tracked.state {
835 ObjectState::Deleted | ObjectState::Detached => {
836 }
838 ObjectState::Expired => {
839 tracing::debug!("Object is expired, reloading from database");
841 }
842 ObjectState::New | ObjectState::Persistent => {
843 if let Some(obj) = tracked.object.downcast_ref::<M>() {
844 return Outcome::Ok(Some(obj.clone()));
845 }
846 }
847 }
848 }
849 }
850
851 let pk_columns = M::PRIMARY_KEY;
853 if pk_columns.len() != pk_values.len() {
854 return Outcome::Err(Error::Custom(format!(
855 "Primary key mismatch: expected {} values, got {}",
856 pk_columns.len(),
857 pk_values.len()
858 )));
859 }
860
861 let where_parts: Vec<String> = pk_columns
862 .iter()
863 .enumerate()
864 .map(|(i, col)| format!("\"{}\" = ${}", col, i + 1))
865 .collect();
866
867 let mut sql = format!(
868 "SELECT * FROM \"{}\" WHERE {} LIMIT 1",
869 M::TABLE_NAME,
870 where_parts.join(" AND ")
871 );
872
873 if options.with_for_update {
875 sql.push_str(" FOR UPDATE");
876 if options.skip_locked {
877 sql.push_str(" SKIP LOCKED");
878 } else if options.nowait {
879 sql.push_str(" NOWAIT");
880 }
881 }
882
883 let rows = match self.connection.query(cx, &sql, pk_values).await {
884 Outcome::Ok(rows) => rows,
885 Outcome::Err(e) => return Outcome::Err(e),
886 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
887 Outcome::Panicked(p) => return Outcome::Panicked(p),
888 };
889
890 if rows.is_empty() {
891 return Outcome::Ok(None);
892 }
893
894 let obj = match M::from_row(&rows[0]) {
896 Ok(obj) => obj,
897 Err(e) => return Outcome::Err(e),
898 };
899
900 let row_data = obj.to_row();
902 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
903 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
904
905 let serialized = serde_json::to_vec(&values).ok();
907
908 let pk_cols: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
910 let obj_pk_values = obj.primary_key_value();
911
912 let tracked = TrackedObject {
913 object: Box::new(obj.clone()),
914 original_state: serialized,
915 state: ObjectState::Persistent,
916 table_name: M::TABLE_NAME,
917 column_names,
918 values,
919 pk_columns: pk_cols,
920 pk_values: obj_pk_values,
921 relationships: M::RELATIONSHIPS,
922 expired_attributes: None,
923 };
924
925 self.identity_map.insert(key, tracked);
926
927 Outcome::Ok(Some(obj))
928 }
929
930 pub fn contains<M: Model + 'static>(&self, obj: &M) -> bool {
932 let key = ObjectKey::from_model(obj);
933 self.identity_map.contains_key(&key)
934 }
935
936 pub fn expunge<M: Model + 'static>(&mut self, obj: &M) {
938 let key = ObjectKey::from_model(obj);
939 if let Some(tracked) = self.identity_map.get_mut(&key) {
940 tracked.state = ObjectState::Detached;
941 }
942 self.pending_new.retain(|k| k != &key);
943 self.pending_delete.retain(|k| k != &key);
944 self.pending_dirty.retain(|k| k != &key);
945 }
946
947 pub fn expunge_all(&mut self) {
949 for tracked in self.identity_map.values_mut() {
950 tracked.state = ObjectState::Detached;
951 }
952 self.pending_new.clear();
953 self.pending_delete.clear();
954 self.pending_dirty.clear();
955 }
956
957 pub fn is_modified<M: Model + Serialize + 'static>(&self, obj: &M) -> bool {
986 let key = ObjectKey::from_model(obj);
987
988 let Some(tracked) = self.identity_map.get(&key) else {
989 return false;
990 };
991
992 match tracked.state {
993 ObjectState::New => true,
995
996 ObjectState::Deleted => true,
998
999 ObjectState::Detached | ObjectState::Expired => false,
1001
1002 ObjectState::Persistent => {
1004 if self.pending_dirty.contains(&key) {
1006 return true;
1007 }
1008
1009 let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
1011 tracked.original_state.as_ref() != Some(¤t_state)
1012 }
1013 }
1014 }
1015
1016 pub fn modified_attributes<M: Model + Serialize + 'static>(
1035 &self,
1036 obj: &M,
1037 ) -> Vec<&'static str> {
1038 let key = ObjectKey::from_model(obj);
1039
1040 let Some(tracked) = self.identity_map.get(&key) else {
1041 return Vec::new();
1042 };
1043
1044 if tracked.state != ObjectState::Persistent {
1046 return Vec::new();
1047 }
1048
1049 let Some(original_bytes) = &tracked.original_state else {
1051 return Vec::new();
1052 };
1053
1054 let Ok(original_values): Result<Vec<Value>, _> = serde_json::from_slice(original_bytes)
1056 else {
1057 return Vec::new();
1058 };
1059
1060 let mut modified = Vec::new();
1062 for (i, col) in tracked.column_names.iter().enumerate() {
1063 let current = tracked.values.get(i);
1064 let original = original_values.get(i);
1065
1066 if current != original {
1067 modified.push(*col);
1068 }
1069 }
1070
1071 modified
1072 }
1073
1074 pub fn object_state<M: Model + 'static>(&self, obj: &M) -> Option<ObjectState> {
1078 let key = ObjectKey::from_model(obj);
1079 self.identity_map.get(&key).map(|t| t.state)
1080 }
1081
1082 #[tracing::instrument(level = "debug", skip(self, obj), fields(table = M::TABLE_NAME))]
1116 pub fn expire<M: Model + 'static>(&mut self, obj: &M, attributes: Option<&[&str]>) {
1117 let key = ObjectKey::from_model(obj);
1118
1119 let Some(tracked) = self.identity_map.get_mut(&key) else {
1120 tracing::debug!("Object not tracked, nothing to expire");
1121 return;
1122 };
1123
1124 match tracked.state {
1126 ObjectState::New | ObjectState::Detached | ObjectState::Deleted => {
1127 tracing::debug!(state = ?tracked.state, "Cannot expire object in this state");
1128 return;
1129 }
1130 ObjectState::Persistent | ObjectState::Expired => {}
1131 }
1132
1133 match attributes {
1134 None => {
1135 tracked.state = ObjectState::Expired;
1137 tracked.expired_attributes = None;
1138 tracing::debug!("Expired all attributes");
1139 }
1140 Some(attrs) => {
1141 let mut expired = tracked.expired_attributes.take().unwrap_or_default();
1143 for attr in attrs {
1144 expired.insert((*attr).to_string());
1145 }
1146 tracked.expired_attributes = Some(expired);
1147
1148 if tracked.state == ObjectState::Persistent {
1150 tracked.state = ObjectState::Expired;
1151 }
1152 tracing::debug!(attributes = ?attrs, "Expired specific attributes");
1153 }
1154 }
1155 }
1156
1157 #[tracing::instrument(level = "debug", skip(self))]
1178 pub fn expire_all(&mut self) {
1179 let mut expired_count = 0;
1180 for tracked in self.identity_map.values_mut() {
1181 if tracked.state == ObjectState::Persistent {
1182 tracked.state = ObjectState::Expired;
1183 tracked.expired_attributes = None;
1184 expired_count += 1;
1185 }
1186 }
1187 tracing::debug!(count = expired_count, "Expired all session objects");
1188 }
1189
1190 pub fn is_expired<M: Model + 'static>(&self, obj: &M) -> bool {
1195 let key = ObjectKey::from_model(obj);
1196 self.identity_map
1197 .get(&key)
1198 .is_some_and(|t| t.state == ObjectState::Expired)
1199 }
1200
1201 pub fn expired_attributes<M: Model + 'static>(
1208 &self,
1209 obj: &M,
1210 ) -> Option<Option<&std::collections::HashSet<String>>> {
1211 let key = ObjectKey::from_model(obj);
1212 let tracked = self.identity_map.get(&key)?;
1213
1214 if tracked.state != ObjectState::Expired {
1215 return None;
1216 }
1217
1218 Some(tracked.expired_attributes.as_ref())
1219 }
1220
1221 #[tracing::instrument(level = "debug", skip(self, cx, obj), fields(table = M::TABLE_NAME))]
1256 pub async fn refresh<
1257 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1258 >(
1259 &mut self,
1260 cx: &Cx,
1261 obj: &M,
1262 ) -> Outcome<Option<M>, Error> {
1263 let pk_values = obj.primary_key_value();
1264 let key = ObjectKey::from_model(obj);
1265
1266 tracing::debug!(pk = ?pk_values, "Refreshing object from database");
1267
1268 self.pending_dirty.retain(|k| k != &key);
1270
1271 self.identity_map.remove(&key);
1273
1274 let result = self.get_by_pk::<M>(cx, &pk_values).await;
1276
1277 match &result {
1278 Outcome::Ok(Some(_)) => {
1279 tracing::debug!("Object refreshed successfully");
1280 }
1281 Outcome::Ok(None) => {
1282 tracing::debug!("Object no longer exists in database");
1283 }
1284 _ => {}
1285 }
1286
1287 result
1288 }
1289
1290 pub async fn begin(&mut self, cx: &Cx) -> Outcome<(), Error> {
1296 if self.in_transaction {
1297 return Outcome::Ok(());
1298 }
1299
1300 match self.connection.execute(cx, "BEGIN", &[]).await {
1301 Outcome::Ok(_) => {
1302 self.in_transaction = true;
1303 Outcome::Ok(())
1304 }
1305 Outcome::Err(e) => Outcome::Err(e),
1306 Outcome::Cancelled(r) => Outcome::Cancelled(r),
1307 Outcome::Panicked(p) => Outcome::Panicked(p),
1308 }
1309 }
1310
1311 pub async fn flush(&mut self, cx: &Cx) -> Outcome<(), Error> {
1315 if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeFlush) {
1317 return Outcome::Err(e);
1318 }
1319
1320 if self.config.auto_begin && !self.in_transaction {
1322 match self.begin(cx).await {
1323 Outcome::Ok(()) => {}
1324 Outcome::Err(e) => return Outcome::Err(e),
1325 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1326 Outcome::Panicked(p) => return Outcome::Panicked(p),
1327 }
1328 }
1329
1330 let dialect = self.connection.dialect();
1331
1332 let deletes: Vec<ObjectKey> = std::mem::take(&mut self.pending_delete);
1334
1335 let mut cascade_child_deletes_single: HashMap<(&'static str, &'static str), Vec<Value>> =
1340 HashMap::new();
1341 let mut cascade_child_deletes_composite: HashMap<CascadeChildDeleteKey, Vec<Vec<Value>>> =
1342 HashMap::new();
1343 let mut cascade_link_deletes_single: HashMap<(&'static str, &'static str), Vec<Value>> =
1344 HashMap::new();
1345 let mut cascade_link_deletes_composite: HashMap<CascadeChildDeleteKey, Vec<Vec<Value>>> =
1346 HashMap::new();
1347
1348 for key in &deletes {
1349 let Some(tracked) = self.identity_map.get(key) else {
1350 continue;
1351 };
1352 if tracked.state != ObjectState::Deleted {
1353 continue;
1354 }
1355 let parent_pk_values = tracked.pk_values.clone();
1356
1357 for rel in tracked.relationships {
1358 if !rel.cascade_delete || rel.is_passive_deletes_all() {
1359 continue;
1360 }
1361
1362 match rel.kind {
1363 sqlmodel_core::RelationshipKind::OneToMany
1364 | sqlmodel_core::RelationshipKind::OneToOne => {
1365 if matches!(rel.passive_deletes, sqlmodel_core::PassiveDeletes::Passive) {
1368 continue;
1369 }
1370 let fk_cols = rel.remote_key_cols();
1371 if fk_cols.is_empty() {
1372 continue;
1373 }
1374 if fk_cols.len() == 1 && parent_pk_values.len() == 1 {
1375 cascade_child_deletes_single
1376 .entry((rel.related_table, fk_cols[0]))
1377 .or_default()
1378 .push(parent_pk_values[0].clone());
1379 } else {
1380 if fk_cols.len() != parent_pk_values.len() {
1382 continue;
1383 }
1384 cascade_child_deletes_composite
1385 .entry(CascadeChildDeleteKey {
1386 table: rel.related_table,
1387 fk_cols: fk_cols.to_vec(),
1388 })
1389 .or_default()
1390 .push(parent_pk_values.clone());
1391 }
1392 }
1393 sqlmodel_core::RelationshipKind::ManyToMany => {
1394 if matches!(rel.passive_deletes, sqlmodel_core::PassiveDeletes::Passive) {
1395 continue;
1396 }
1397 let Some(link) = rel.link_table else {
1398 continue;
1399 };
1400 let local_cols = link.local_cols();
1401 if local_cols.is_empty() {
1402 continue;
1403 }
1404 if local_cols.len() == 1 && parent_pk_values.len() == 1 {
1405 cascade_link_deletes_single
1406 .entry((link.table_name, local_cols[0]))
1407 .or_default()
1408 .push(parent_pk_values[0].clone());
1409 } else {
1410 if local_cols.len() != parent_pk_values.len() {
1411 continue;
1412 }
1413 cascade_link_deletes_composite
1414 .entry(CascadeChildDeleteKey {
1415 table: link.table_name,
1416 fk_cols: local_cols.to_vec(),
1417 })
1418 .or_default()
1419 .push(parent_pk_values.clone());
1420 }
1421 }
1422 sqlmodel_core::RelationshipKind::ManyToOne => {}
1423 }
1424 }
1425 }
1426
1427 let dedup_by_hash = |vals: &mut Vec<Value>| {
1428 let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
1429 vals.retain(|v| seen.insert(hash_values(std::slice::from_ref(v))));
1430 };
1431
1432 for ((child_table, fk_col), mut pks) in cascade_child_deletes_single {
1434 dedup_by_hash(&mut pks);
1435 if pks.is_empty() {
1436 continue;
1437 }
1438
1439 let placeholders: Vec<String> =
1440 (1..=pks.len()).map(|i| dialect.placeholder(i)).collect();
1441 let sql = format!(
1442 "DELETE FROM {} WHERE {} IN ({})",
1443 dialect.quote_identifier(child_table),
1444 dialect.quote_identifier(fk_col),
1445 placeholders.join(", ")
1446 );
1447
1448 match self.connection.execute(cx, &sql, &pks).await {
1449 Outcome::Ok(_) => {}
1450 Outcome::Err(e) => {
1451 self.pending_delete = deletes;
1452 return Outcome::Err(e);
1453 }
1454 Outcome::Cancelled(r) => {
1455 self.pending_delete = deletes;
1456 return Outcome::Cancelled(r);
1457 }
1458 Outcome::Panicked(p) => {
1459 self.pending_delete = deletes;
1460 return Outcome::Panicked(p);
1461 }
1462 }
1463
1464 let pk_hashes: std::collections::HashSet<u64> = pks
1466 .iter()
1467 .map(|v| hash_values(std::slice::from_ref(v)))
1468 .collect();
1469 let mut to_remove: Vec<ObjectKey> = Vec::new();
1470 for (k, t) in &self.identity_map {
1471 if t.table_name != child_table {
1472 continue;
1473 }
1474 let Some(idx) = t.column_names.iter().position(|col| *col == fk_col) else {
1475 continue;
1476 };
1477 let fk_val = &t.values[idx];
1478 if pk_hashes.contains(&hash_values(std::slice::from_ref(fk_val))) {
1479 to_remove.push(*k);
1480 }
1481 }
1482 for k in &to_remove {
1483 self.identity_map.remove(k);
1484 }
1485 self.pending_new.retain(|k| !to_remove.contains(k));
1486 self.pending_dirty.retain(|k| !to_remove.contains(k));
1487 self.pending_delete.retain(|k| !to_remove.contains(k));
1488 }
1489
1490 for (key, mut tuples) in cascade_child_deletes_composite {
1492 if tuples.is_empty() {
1493 continue;
1494 }
1495
1496 let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
1497 tuples.retain(|t| seen.insert(hash_values(t)));
1498
1499 if tuples.is_empty() {
1500 continue;
1501 }
1502
1503 let col_list = key
1504 .fk_cols
1505 .iter()
1506 .map(|c| dialect.quote_identifier(c))
1507 .collect::<Vec<_>>()
1508 .join(", ");
1509
1510 let mut params: Vec<Value> = Vec::with_capacity(tuples.len() * key.fk_cols.len());
1511 let mut idx = 1;
1512 let tuple_sql: Vec<String> = tuples
1513 .iter()
1514 .map(|t| {
1515 for v in t {
1516 params.push(v.clone());
1517 }
1518 let inner = (0..key.fk_cols.len())
1519 .map(|_| {
1520 let ph = dialect.placeholder(idx);
1521 idx += 1;
1522 ph
1523 })
1524 .collect::<Vec<_>>()
1525 .join(", ");
1526 format!("({})", inner)
1527 })
1528 .collect();
1529
1530 let sql = format!(
1531 "DELETE FROM {} WHERE ({}) IN ({})",
1532 dialect.quote_identifier(key.table),
1533 col_list,
1534 tuple_sql.join(", ")
1535 );
1536
1537 match self.connection.execute(cx, &sql, ¶ms).await {
1538 Outcome::Ok(_) => {}
1539 Outcome::Err(e) => {
1540 self.pending_delete = deletes;
1541 return Outcome::Err(e);
1542 }
1543 Outcome::Cancelled(r) => {
1544 self.pending_delete = deletes;
1545 return Outcome::Cancelled(r);
1546 }
1547 Outcome::Panicked(p) => {
1548 self.pending_delete = deletes;
1549 return Outcome::Panicked(p);
1550 }
1551 }
1552
1553 let tuple_hashes: std::collections::HashSet<u64> =
1555 tuples.iter().map(|t| hash_values(t)).collect();
1556 let mut to_remove: Vec<ObjectKey> = Vec::new();
1557 for (k, t) in &self.identity_map {
1558 if t.table_name != key.table {
1559 continue;
1560 }
1561
1562 let mut child_fk: Vec<Value> = Vec::with_capacity(key.fk_cols.len());
1563 let mut missing = false;
1564 for fk_col in &key.fk_cols {
1565 let Some(idx) = t.column_names.iter().position(|col| col == fk_col) else {
1566 missing = true;
1567 break;
1568 };
1569 child_fk.push(t.values[idx].clone());
1570 }
1571 if missing {
1572 continue;
1573 }
1574 if tuple_hashes.contains(&hash_values(&child_fk)) {
1575 to_remove.push(*k);
1576 }
1577 }
1578 for k in &to_remove {
1579 self.identity_map.remove(k);
1580 }
1581 self.pending_new.retain(|k| !to_remove.contains(k));
1582 self.pending_dirty.retain(|k| !to_remove.contains(k));
1583 self.pending_delete.retain(|k| !to_remove.contains(k));
1584 }
1585
1586 for ((link_table, local_col), mut pks) in cascade_link_deletes_single {
1588 dedup_by_hash(&mut pks);
1589 if pks.is_empty() {
1590 continue;
1591 }
1592
1593 let placeholders: Vec<String> =
1594 (1..=pks.len()).map(|i| dialect.placeholder(i)).collect();
1595 let sql = format!(
1596 "DELETE FROM {} WHERE {} IN ({})",
1597 dialect.quote_identifier(link_table),
1598 dialect.quote_identifier(local_col),
1599 placeholders.join(", ")
1600 );
1601
1602 match self.connection.execute(cx, &sql, &pks).await {
1603 Outcome::Ok(_) => {}
1604 Outcome::Err(e) => {
1605 self.pending_delete = deletes;
1606 return Outcome::Err(e);
1607 }
1608 Outcome::Cancelled(r) => {
1609 self.pending_delete = deletes;
1610 return Outcome::Cancelled(r);
1611 }
1612 Outcome::Panicked(p) => {
1613 self.pending_delete = deletes;
1614 return Outcome::Panicked(p);
1615 }
1616 }
1617 }
1618
1619 for (key, mut tuples) in cascade_link_deletes_composite {
1621 if tuples.is_empty() {
1622 continue;
1623 }
1624
1625 let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
1626 tuples.retain(|t| seen.insert(hash_values(t)));
1627
1628 if tuples.is_empty() {
1629 continue;
1630 }
1631
1632 let col_list = key
1633 .fk_cols
1634 .iter()
1635 .map(|c| dialect.quote_identifier(c))
1636 .collect::<Vec<_>>()
1637 .join(", ");
1638
1639 let mut params: Vec<Value> = Vec::with_capacity(tuples.len() * key.fk_cols.len());
1640 let mut idx = 1;
1641 let tuple_sql: Vec<String> = tuples
1642 .iter()
1643 .map(|t| {
1644 for v in t {
1645 params.push(v.clone());
1646 }
1647 let inner = (0..key.fk_cols.len())
1648 .map(|_| {
1649 let ph = dialect.placeholder(idx);
1650 idx += 1;
1651 ph
1652 })
1653 .collect::<Vec<_>>()
1654 .join(", ");
1655 format!("({})", inner)
1656 })
1657 .collect();
1658
1659 let sql = format!(
1660 "DELETE FROM {} WHERE ({}) IN ({})",
1661 dialect.quote_identifier(key.table),
1662 col_list,
1663 tuple_sql.join(", ")
1664 );
1665
1666 match self.connection.execute(cx, &sql, ¶ms).await {
1667 Outcome::Ok(_) => {}
1668 Outcome::Err(e) => {
1669 self.pending_delete = deletes;
1670 return Outcome::Err(e);
1671 }
1672 Outcome::Cancelled(r) => {
1673 self.pending_delete = deletes;
1674 return Outcome::Cancelled(r);
1675 }
1676 Outcome::Panicked(p) => {
1677 self.pending_delete = deletes;
1678 return Outcome::Panicked(p);
1679 }
1680 }
1681 }
1682
1683 let mut actually_deleted: Vec<ObjectKey> = Vec::new();
1684 for key in &deletes {
1685 if let Some(tracked) = self.identity_map.get(key) {
1686 if tracked.state != ObjectState::Deleted {
1688 continue;
1689 }
1690
1691 if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
1693 tracing::warn!(
1694 table = tracked.table_name,
1695 "Skipping DELETE for object without primary key - cannot identify row"
1696 );
1697 continue;
1698 }
1699
1700 let pk_columns = tracked.pk_columns.clone();
1702 let pk_values = tracked.pk_values.clone();
1703 let table_name = tracked.table_name;
1704 let relationships = tracked.relationships;
1705
1706 let where_parts: Vec<String> = pk_columns
1708 .iter()
1709 .enumerate()
1710 .map(|(i, col)| {
1711 format!(
1712 "{} = {}",
1713 dialect.quote_identifier(col),
1714 dialect.placeholder(i + 1)
1715 )
1716 })
1717 .collect();
1718
1719 let sql = format!(
1720 "DELETE FROM {} WHERE {}",
1721 dialect.quote_identifier(table_name),
1722 where_parts.join(" AND ")
1723 );
1724
1725 match self.connection.execute(cx, &sql, &pk_values).await {
1726 Outcome::Ok(_) => {
1727 actually_deleted.push(*key);
1728
1729 if !pk_values.is_empty() {
1732 let mut to_remove: Vec<ObjectKey> = Vec::new();
1733 for rel in relationships {
1734 if !rel.cascade_delete
1735 || !matches!(
1736 rel.passive_deletes,
1737 sqlmodel_core::PassiveDeletes::Passive
1738 )
1739 {
1740 continue;
1741 }
1742 if !matches!(
1743 rel.kind,
1744 sqlmodel_core::RelationshipKind::OneToMany
1745 | sqlmodel_core::RelationshipKind::OneToOne
1746 ) {
1747 continue;
1748 }
1749
1750 let fk_cols = rel.remote_key_cols();
1751 if fk_cols.is_empty() || fk_cols.len() != pk_values.len() {
1752 continue;
1753 }
1754
1755 for (k, t) in &self.identity_map {
1756 if t.table_name != rel.related_table {
1757 continue;
1758 }
1759 let mut matches_parent = true;
1760 for (fk_col, parent_val) in fk_cols.iter().zip(&pk_values) {
1761 let Some(idx) =
1762 t.column_names.iter().position(|col| col == fk_col)
1763 else {
1764 matches_parent = false;
1765 break;
1766 };
1767 if &t.values[idx] != parent_val {
1768 matches_parent = false;
1769 break;
1770 }
1771 }
1772 if matches_parent {
1773 to_remove.push(*k);
1774 }
1775 }
1776 }
1777
1778 for k in &to_remove {
1779 self.identity_map.remove(k);
1780 }
1781 self.pending_new.retain(|k| !to_remove.contains(k));
1782 self.pending_dirty.retain(|k| !to_remove.contains(k));
1783 self.pending_delete.retain(|k| !to_remove.contains(k));
1784 }
1785 }
1786 Outcome::Err(e) => {
1787 self.pending_delete = deletes
1790 .into_iter()
1791 .filter(|k| !actually_deleted.contains(k))
1792 .collect();
1793 for key in &actually_deleted {
1795 self.identity_map.remove(key);
1796 }
1797 return Outcome::Err(e);
1798 }
1799 Outcome::Cancelled(r) => {
1800 self.pending_delete = deletes
1802 .into_iter()
1803 .filter(|k| !actually_deleted.contains(k))
1804 .collect();
1805 for key in &actually_deleted {
1806 self.identity_map.remove(key);
1807 }
1808 return Outcome::Cancelled(r);
1809 }
1810 Outcome::Panicked(p) => {
1811 self.pending_delete = deletes
1813 .into_iter()
1814 .filter(|k| !actually_deleted.contains(k))
1815 .collect();
1816 for key in &actually_deleted {
1817 self.identity_map.remove(key);
1818 }
1819 return Outcome::Panicked(p);
1820 }
1821 }
1822 }
1823 }
1824
1825 for key in &actually_deleted {
1827 self.identity_map.remove(key);
1828 }
1829
1830 let inserts: Vec<ObjectKey> = std::mem::take(&mut self.pending_new);
1832 for key in &inserts {
1833 if let Some(tracked) = self.identity_map.get_mut(key) {
1834 if tracked.state == ObjectState::Persistent {
1836 continue;
1837 }
1838
1839 let columns = &tracked.column_names;
1841 let columns_sql: Vec<String> = columns
1842 .iter()
1843 .map(|c| dialect.quote_identifier(c))
1844 .collect();
1845 let placeholders: Vec<String> = (1..=columns.len())
1846 .map(|i| dialect.placeholder(i))
1847 .collect();
1848
1849 let sql = format!(
1850 "INSERT INTO {} ({}) VALUES ({})",
1851 dialect.quote_identifier(tracked.table_name),
1852 columns_sql.join(", "),
1853 placeholders.join(", ")
1854 );
1855
1856 match self.connection.execute(cx, &sql, &tracked.values).await {
1857 Outcome::Ok(_) => {
1858 tracked.state = ObjectState::Persistent;
1859 tracked.original_state =
1861 Some(serde_json::to_vec(&tracked.values).unwrap_or_default());
1862 }
1863 Outcome::Err(e) => {
1864 self.pending_new = inserts;
1866 return Outcome::Err(e);
1867 }
1868 Outcome::Cancelled(r) => {
1869 self.pending_new = inserts;
1871 return Outcome::Cancelled(r);
1872 }
1873 Outcome::Panicked(p) => {
1874 self.pending_new = inserts;
1876 return Outcome::Panicked(p);
1877 }
1878 }
1879 }
1880 }
1881
1882 let dirty: Vec<ObjectKey> = std::mem::take(&mut self.pending_dirty);
1884 for key in &dirty {
1885 if let Some(tracked) = self.identity_map.get_mut(key) {
1886 if tracked.state != ObjectState::Persistent {
1888 continue;
1889 }
1890
1891 if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
1893 tracing::warn!(
1894 table = tracked.table_name,
1895 "Skipping UPDATE for object without primary key - cannot identify row"
1896 );
1897 continue;
1898 }
1899
1900 let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
1902 let is_dirty = tracked.original_state.as_ref() != Some(¤t_state);
1903
1904 if !is_dirty {
1905 continue;
1906 }
1907
1908 let mut set_parts = Vec::new();
1910 let mut params = Vec::new();
1911 let mut param_idx = 1;
1912
1913 for (i, col) in tracked.column_names.iter().enumerate() {
1914 if !tracked.pk_columns.contains(col) {
1916 set_parts.push(format!(
1917 "{} = {}",
1918 dialect.quote_identifier(col),
1919 dialect.placeholder(param_idx)
1920 ));
1921 params.push(tracked.values[i].clone());
1922 param_idx += 1;
1923 }
1924 }
1925
1926 let where_parts: Vec<String> = tracked
1928 .pk_columns
1929 .iter()
1930 .map(|col| {
1931 let clause = format!(
1932 "{} = {}",
1933 dialect.quote_identifier(col),
1934 dialect.placeholder(param_idx)
1935 );
1936 param_idx += 1;
1937 clause
1938 })
1939 .collect();
1940
1941 params.extend(tracked.pk_values.clone());
1943
1944 if set_parts.is_empty() {
1945 continue; }
1947
1948 let sql = format!(
1949 "UPDATE {} SET {} WHERE {}",
1950 dialect.quote_identifier(tracked.table_name),
1951 set_parts.join(", "),
1952 where_parts.join(" AND ")
1953 );
1954
1955 match self.connection.execute(cx, &sql, ¶ms).await {
1956 Outcome::Ok(_) => {
1957 tracked.original_state = Some(current_state);
1959 }
1960 Outcome::Err(e) => {
1961 self.pending_dirty = dirty;
1963 return Outcome::Err(e);
1964 }
1965 Outcome::Cancelled(r) => {
1966 self.pending_dirty = dirty;
1968 return Outcome::Cancelled(r);
1969 }
1970 Outcome::Panicked(p) => {
1971 self.pending_dirty = dirty;
1973 return Outcome::Panicked(p);
1974 }
1975 }
1976 }
1977 }
1978
1979 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterFlush) {
1981 return Outcome::Err(e);
1982 }
1983
1984 Outcome::Ok(())
1985 }
1986
1987 pub async fn commit(&mut self, cx: &Cx) -> Outcome<(), Error> {
1989 match self.flush(cx).await {
1991 Outcome::Ok(()) => {}
1992 Outcome::Err(e) => return Outcome::Err(e),
1993 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1994 Outcome::Panicked(p) => return Outcome::Panicked(p),
1995 }
1996
1997 if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeCommit) {
1999 return Outcome::Err(e);
2000 }
2001
2002 if self.in_transaction {
2003 match self.connection.execute(cx, "COMMIT", &[]).await {
2004 Outcome::Ok(_) => {
2005 self.in_transaction = false;
2006 }
2007 Outcome::Err(e) => return Outcome::Err(e),
2008 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2009 Outcome::Panicked(p) => return Outcome::Panicked(p),
2010 }
2011 }
2012
2013 if self.config.expire_on_commit {
2015 for tracked in self.identity_map.values_mut() {
2016 if tracked.state == ObjectState::Persistent {
2017 tracked.state = ObjectState::Expired;
2018 }
2019 }
2020 }
2021
2022 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterCommit) {
2024 return Outcome::Err(e);
2025 }
2026
2027 Outcome::Ok(())
2028 }
2029
2030 pub async fn rollback(&mut self, cx: &Cx) -> Outcome<(), Error> {
2032 if self.in_transaction {
2033 match self.connection.execute(cx, "ROLLBACK", &[]).await {
2034 Outcome::Ok(_) => {
2035 self.in_transaction = false;
2036 }
2037 Outcome::Err(e) => return Outcome::Err(e),
2038 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2039 Outcome::Panicked(p) => return Outcome::Panicked(p),
2040 }
2041 }
2042
2043 self.pending_new.clear();
2045 self.pending_delete.clear();
2046 self.pending_dirty.clear();
2047
2048 let mut to_remove = Vec::new();
2050 for (key, tracked) in &mut self.identity_map {
2051 match tracked.state {
2052 ObjectState::New => {
2053 to_remove.push(*key);
2054 }
2055 ObjectState::Deleted => {
2056 tracked.state = ObjectState::Persistent;
2057 }
2058 _ => {}
2059 }
2060 }
2061
2062 for key in to_remove {
2063 self.identity_map.remove(&key);
2064 }
2065
2066 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterRollback) {
2068 return Outcome::Err(e);
2069 }
2070
2071 Outcome::Ok(())
2072 }
2073
2074 #[tracing::instrument(level = "debug", skip(self, lazy, cx))]
2090 pub async fn load_lazy<
2091 T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2092 >(
2093 &mut self,
2094 lazy: &Lazy<T>,
2095 cx: &Cx,
2096 ) -> Outcome<bool, Error> {
2097 tracing::debug!(
2098 model = std::any::type_name::<T>(),
2099 fk = ?lazy.fk(),
2100 already_loaded = lazy.is_loaded(),
2101 "Loading lazy relationship"
2102 );
2103
2104 if lazy.is_loaded() {
2106 tracing::trace!("Already loaded");
2107 return Outcome::Ok(lazy.get().is_some());
2108 }
2109
2110 let Some(fk) = lazy.fk() else {
2112 let _ = lazy.set_loaded(None);
2113 return Outcome::Ok(false);
2114 };
2115
2116 let obj = match self.get::<T>(cx, fk.clone()).await {
2118 Outcome::Ok(obj) => obj,
2119 Outcome::Err(e) => return Outcome::Err(e),
2120 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2121 Outcome::Panicked(p) => return Outcome::Panicked(p),
2122 };
2123
2124 let found = obj.is_some();
2125
2126 let _ = lazy.set_loaded(obj);
2128
2129 tracing::debug!(found = found, "Lazy load complete");
2130
2131 Outcome::Ok(found)
2132 }
2133
2134 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor))]
2158 pub async fn load_many<P, T, F>(
2159 &mut self,
2160 cx: &Cx,
2161 objects: &[P],
2162 accessor: F,
2163 ) -> Outcome<usize, Error>
2164 where
2165 P: Model + 'static,
2166 T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2167 F: Fn(&P) -> &Lazy<T>,
2168 {
2169 let mut fk_values: Vec<Value> = Vec::new();
2171 let mut fk_indices: Vec<usize> = Vec::new();
2172
2173 for (idx, obj) in objects.iter().enumerate() {
2174 let lazy = accessor(obj);
2175 if !lazy.is_loaded() && !lazy.is_empty() {
2176 if let Some(fk) = lazy.fk() {
2177 fk_values.push(fk.clone());
2178 fk_indices.push(idx);
2179 }
2180 }
2181 }
2182
2183 let fk_count = fk_values.len();
2184 tracing::info!(
2185 parent_model = std::any::type_name::<P>(),
2186 related_model = std::any::type_name::<T>(),
2187 parent_count = objects.len(),
2188 fk_count = fk_count,
2189 "Batch loading lazy relationships"
2190 );
2191
2192 if fk_values.is_empty() {
2193 for obj in objects {
2195 let lazy = accessor(obj);
2196 if !lazy.is_loaded() && lazy.is_empty() {
2197 let _ = lazy.set_loaded(None);
2198 }
2199 }
2200 return Outcome::Ok(0);
2201 }
2202
2203 let dialect = self.connection.dialect();
2205 let pk_col = T::PRIMARY_KEY.first().unwrap_or(&"id");
2206 let placeholders: Vec<String> = (1..=fk_values.len())
2207 .map(|i| dialect.placeholder(i))
2208 .collect();
2209 let sql = format!(
2210 "SELECT * FROM {} WHERE {} IN ({})",
2211 dialect.quote_identifier(T::TABLE_NAME),
2212 dialect.quote_identifier(pk_col),
2213 placeholders.join(", ")
2214 );
2215
2216 let rows = match self.connection.query(cx, &sql, &fk_values).await {
2217 Outcome::Ok(rows) => rows,
2218 Outcome::Err(e) => return Outcome::Err(e),
2219 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2220 Outcome::Panicked(p) => return Outcome::Panicked(p),
2221 };
2222
2223 let mut lookup: HashMap<u64, T> = HashMap::new();
2225 for row in &rows {
2226 match T::from_row(row) {
2227 Ok(obj) => {
2228 let pk_values = obj.primary_key_value();
2229 let pk_hash = hash_values(&pk_values);
2230
2231 let key = ObjectKey::from_pk::<T>(&pk_values);
2233
2234 let row_data = obj.to_row();
2236 let column_names: Vec<&'static str> =
2237 row_data.iter().map(|(name, _)| *name).collect();
2238 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
2239
2240 let serialized = serde_json::to_vec(&values).ok();
2242
2243 let tracked = TrackedObject {
2244 object: Box::new(obj.clone()),
2245 original_state: serialized,
2246 state: ObjectState::Persistent,
2247 table_name: T::TABLE_NAME,
2248 column_names,
2249 values,
2250 pk_columns: T::PRIMARY_KEY.to_vec(),
2251 pk_values: pk_values.clone(),
2252 relationships: T::RELATIONSHIPS,
2253 expired_attributes: None,
2254 };
2255 self.identity_map.insert(key, tracked);
2256
2257 lookup.insert(pk_hash, obj);
2259 }
2260 Err(_) => continue,
2261 }
2262 }
2263
2264 let mut loaded_count = 0;
2266 for obj in objects {
2267 let lazy = accessor(obj);
2268 if !lazy.is_loaded() {
2269 if let Some(fk) = lazy.fk() {
2270 let fk_hash = hash_values(std::slice::from_ref(fk));
2271 let related = lookup.get(&fk_hash).cloned();
2272 let found = related.is_some();
2273 let _ = lazy.set_loaded(related);
2274 if found {
2275 loaded_count += 1;
2276 }
2277 } else {
2278 let _ = lazy.set_loaded(None);
2279 }
2280 }
2281 }
2282
2283 tracing::debug!(
2284 query_count = 1,
2285 loaded_count = loaded_count,
2286 "Batch load complete"
2287 );
2288
2289 Outcome::Ok(loaded_count)
2290 }
2291
2292 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
2316 pub async fn load_many_to_many<P, Child, FA, FP>(
2317 &mut self,
2318 cx: &Cx,
2319 objects: &mut [P],
2320 accessor: FA,
2321 parent_pk: FP,
2322 link_table: &sqlmodel_core::LinkTableInfo,
2323 ) -> Outcome<usize, Error>
2324 where
2325 P: Model + 'static,
2326 Child: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2327 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2328 FP: Fn(&P) -> Value,
2329 {
2330 self.load_many_to_many_pk(cx, objects, accessor, |p| vec![parent_pk(p)], link_table)
2331 .await
2332 }
2333
2334 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
2339 pub async fn load_many_to_many_pk<P, Child, FA, FP>(
2340 &mut self,
2341 cx: &Cx,
2342 objects: &mut [P],
2343 accessor: FA,
2344 parent_pk: FP,
2345 link_table: &sqlmodel_core::LinkTableInfo,
2346 ) -> Outcome<usize, Error>
2347 where
2348 P: Model + 'static,
2349 Child: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2350 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2351 FP: Fn(&P) -> Vec<Value>,
2352 {
2353 let mut pk_tuples: Vec<Vec<Value>> = Vec::with_capacity(objects.len());
2355 let mut pk_by_index: Vec<(usize, Vec<Value>)> = Vec::new();
2356 for (idx, obj) in objects.iter().enumerate() {
2357 let pk = parent_pk(obj);
2358 pk_tuples.push(pk.clone());
2359 pk_by_index.push((idx, pk));
2360 }
2361
2362 tracing::info!(
2363 parent_model = std::any::type_name::<P>(),
2364 related_model = std::any::type_name::<Child>(),
2365 parent_count = pk_tuples.len(),
2366 link_table = link_table.table_name,
2367 "Batch loading many-to-many relationships"
2368 );
2369
2370 if pk_tuples.is_empty() {
2371 return Outcome::Ok(0);
2372 }
2373
2374 let dialect = self.connection.dialect();
2380 let local_cols = link_table.local_cols();
2381 let remote_cols = link_table.remote_cols();
2382 if local_cols.is_empty() || remote_cols.is_empty() {
2383 return Outcome::Err(Error::Custom(
2384 "link_table must specify local/remote columns".to_string(),
2385 ));
2386 }
2387 if remote_cols.len() != Child::PRIMARY_KEY.len() {
2388 return Outcome::Err(Error::Custom(format!(
2389 "link_table remote cols count ({}) must match child PRIMARY_KEY len ({})",
2390 remote_cols.len(),
2391 Child::PRIMARY_KEY.len()
2392 )));
2393 }
2394
2395 let child_table = dialect.quote_identifier(Child::TABLE_NAME);
2396 let link_table_q = dialect.quote_identifier(link_table.table_name);
2397
2398 let parent_select_parts: String = local_cols
2399 .iter()
2400 .enumerate()
2401 .map(|(i, col)| {
2402 format!(
2403 "{link_table_q}.{} AS __parent_pk{}",
2404 dialect.quote_identifier(col),
2405 i
2406 )
2407 })
2408 .collect::<Vec<_>>()
2409 .join(", ");
2410
2411 let join_parts: String = remote_cols
2412 .iter()
2413 .zip(Child::PRIMARY_KEY.iter().copied())
2414 .map(|(link_col, child_col)| {
2415 format!(
2416 "{child_table}.{} = {link_table_q}.{}",
2417 dialect.quote_identifier(child_col),
2418 dialect.quote_identifier(link_col)
2419 )
2420 })
2421 .collect::<Vec<_>>()
2422 .join(" AND ");
2423
2424 let (where_sql, params) = if local_cols.len() == 1 {
2425 let mut params: Vec<Value> = Vec::with_capacity(pk_tuples.len());
2426 for t in &pk_tuples {
2427 if let Some(v) = t.first() {
2428 params.push(v.clone());
2429 }
2430 }
2431 let placeholders: Vec<String> =
2432 (1..=params.len()).map(|i| dialect.placeholder(i)).collect();
2433 let where_sql = format!(
2434 "{link_table_q}.{} IN ({})",
2435 dialect.quote_identifier(local_cols[0]),
2436 placeholders.join(", ")
2437 );
2438 (where_sql, params)
2439 } else {
2440 let mut tuples: Vec<Vec<Value>> = Vec::with_capacity(pk_tuples.len());
2441 for t in &pk_tuples {
2442 if t.len() == local_cols.len() {
2443 tuples.push(t.clone());
2444 }
2445 }
2446
2447 let mut params: Vec<Value> = Vec::with_capacity(tuples.len() * local_cols.len());
2448 let mut idx = 1;
2449 let tuple_sql: Vec<String> = tuples
2450 .iter()
2451 .map(|t| {
2452 for v in t {
2453 params.push(v.clone());
2454 }
2455 let inner = (0..local_cols.len())
2456 .map(|_| {
2457 let ph = dialect.placeholder(idx);
2458 idx += 1;
2459 ph
2460 })
2461 .collect::<Vec<_>>()
2462 .join(", ");
2463 format!("({})", inner)
2464 })
2465 .collect();
2466
2467 let col_list = local_cols
2468 .iter()
2469 .map(|c| format!("{link_table_q}.{}", dialect.quote_identifier(c)))
2470 .collect::<Vec<_>>()
2471 .join(", ");
2472
2473 let where_sql = format!("({}) IN ({})", col_list, tuple_sql.join(", "));
2474 (where_sql, params)
2475 };
2476
2477 let sql = format!(
2478 "SELECT {child_table}.*, {parent_select_parts} FROM {child_table} \
2479 JOIN {link_table_q} ON {join_parts} \
2480 WHERE {where_sql}"
2481 );
2482
2483 tracing::trace!(sql = %sql, "Many-to-many batch SQL");
2484
2485 let rows = match self.connection.query(cx, &sql, ¶ms).await {
2486 Outcome::Ok(rows) => rows,
2487 Outcome::Err(e) => return Outcome::Err(e),
2488 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2489 Outcome::Panicked(p) => return Outcome::Panicked(p),
2490 };
2491
2492 let mut by_parent: HashMap<u64, Vec<Child>> = HashMap::new();
2494 for row in &rows {
2495 let mut parent_tuple: Vec<Value> = Vec::with_capacity(local_cols.len());
2497 let mut missing = false;
2498 for i in 0..local_cols.len() {
2499 let col = format!("__parent_pk{}", i);
2500 let Some(v) = row.get_by_name(&col) else {
2501 missing = true;
2502 break;
2503 };
2504 parent_tuple.push(v.clone());
2505 }
2506 if missing {
2507 continue;
2508 }
2509 let parent_pk_hash = hash_values(&parent_tuple);
2510
2511 match Child::from_row(row) {
2513 Ok(child) => {
2514 by_parent.entry(parent_pk_hash).or_default().push(child);
2515 }
2516 Err(_) => continue,
2517 }
2518 }
2519
2520 let mut loaded_count = 0;
2522 for (idx, pk_tuple) in pk_by_index {
2523 let pk_hash = hash_values(&pk_tuple);
2524 let children = by_parent.get(&pk_hash).cloned().unwrap_or_default();
2526 let child_count = children.len();
2527
2528 let related = accessor(&mut objects[idx]);
2529 if pk_tuple.len() == 1 {
2530 related.set_parent_pk(pk_tuple[0].clone());
2531 } else {
2532 related.set_parent_pk(Value::Array(pk_tuple.clone()));
2533 }
2534 let _ = related.set_loaded(children);
2535 loaded_count += child_count;
2536 }
2537
2538 tracing::debug!(
2539 query_count = 1,
2540 total_children = loaded_count,
2541 "Many-to-many batch load complete"
2542 );
2543
2544 Outcome::Ok(loaded_count)
2545 }
2546
2547 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
2556 pub async fn load_one_to_many<P, Child, FA, FP>(
2557 &mut self,
2558 cx: &Cx,
2559 objects: &mut [P],
2560 accessor: FA,
2561 parent_pk: FP,
2562 ) -> Outcome<usize, Error>
2563 where
2564 P: Model + 'static,
2565 Child: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2566 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2567 FP: Fn(&P) -> Value,
2568 {
2569 let mut pks: Vec<Value> = Vec::new();
2571 let mut pk_by_index: Vec<(usize, Value)> = Vec::new();
2572 for (idx, obj) in objects.iter_mut().enumerate() {
2573 let pk = parent_pk(&*obj);
2574 let related = accessor(obj);
2575 if related.is_loaded() {
2576 continue;
2577 }
2578
2579 related.set_parent_pk(pk.clone());
2580
2581 if matches!(pk, Value::Null) {
2582 let _ = related.set_loaded(Vec::new());
2584 continue;
2585 }
2586
2587 pks.push(pk.clone());
2588 pk_by_index.push((idx, pk));
2589 }
2590
2591 tracing::info!(
2592 parent_model = std::any::type_name::<P>(),
2593 related_model = std::any::type_name::<Child>(),
2594 parent_count = objects.len(),
2595 query_parent_count = pks.len(),
2596 "Batch loading one-to-many relationships"
2597 );
2598
2599 if pks.is_empty() {
2600 return Outcome::Ok(0);
2601 }
2602
2603 let fk_column = accessor(&mut objects[pk_by_index[0].0]).fk_column();
2605 let dialect = self.connection.dialect();
2606 let placeholders: Vec<String> = (1..=pks.len()).map(|i| dialect.placeholder(i)).collect();
2607 let child_table = dialect.quote_identifier(Child::TABLE_NAME);
2608 let fk_q = dialect.quote_identifier(fk_column);
2609 let sql = format!(
2610 "SELECT *, {fk_q} AS __parent_pk FROM {child_table} WHERE {fk_q} IN ({})",
2611 placeholders.join(", ")
2612 );
2613
2614 tracing::trace!(sql = %sql, "One-to-many batch SQL");
2615
2616 let rows = match self.connection.query(cx, &sql, &pks).await {
2617 Outcome::Ok(rows) => rows,
2618 Outcome::Err(e) => return Outcome::Err(e),
2619 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2620 Outcome::Panicked(p) => return Outcome::Panicked(p),
2621 };
2622
2623 let mut by_parent: HashMap<u64, Vec<Child>> = HashMap::new();
2625 for row in &rows {
2626 let parent_pk_value: Value = match row.get_by_name("__parent_pk") {
2627 Some(v) => v.clone(),
2628 None => continue,
2629 };
2630 let parent_pk_hash = hash_values(std::slice::from_ref(&parent_pk_value));
2631 match Child::from_row(row) {
2632 Ok(child) => {
2633 let pk_values = child.primary_key_value();
2635 let key = ObjectKey::from_pk::<Child>(&pk_values);
2636
2637 self.identity_map.entry(key).or_insert_with(|| {
2638 let row_data = child.to_row();
2640 let column_names: Vec<&'static str> =
2641 row_data.iter().map(|(name, _)| *name).collect();
2642 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
2643
2644 let serialized = serde_json::to_vec(&values).ok();
2646
2647 TrackedObject {
2648 object: Box::new(child.clone()),
2649 original_state: serialized,
2650 state: ObjectState::Persistent,
2651 table_name: Child::TABLE_NAME,
2652 column_names,
2653 values,
2654 pk_columns: Child::PRIMARY_KEY.to_vec(),
2655 pk_values: pk_values.clone(),
2656 relationships: Child::RELATIONSHIPS,
2657 expired_attributes: None,
2658 }
2659 });
2660
2661 by_parent.entry(parent_pk_hash).or_default().push(child);
2662 }
2663 Err(_) => continue,
2664 }
2665 }
2666
2667 let mut loaded_count = 0;
2669 for (idx, pk) in pk_by_index {
2670 let pk_hash = hash_values(std::slice::from_ref(&pk));
2671 let children = by_parent.get(&pk_hash).cloned().unwrap_or_default();
2673 loaded_count += children.len();
2674
2675 let related = accessor(&mut objects[idx]);
2676 let _ = related.set_loaded(children);
2677 }
2678
2679 Outcome::Ok(loaded_count)
2680 }
2681
2682 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
2701 pub async fn flush_related_many<P, Child, FA, FP>(
2702 &mut self,
2703 cx: &Cx,
2704 objects: &mut [P],
2705 accessor: FA,
2706 parent_pk: FP,
2707 link_table: &sqlmodel_core::LinkTableInfo,
2708 ) -> Outcome<usize, Error>
2709 where
2710 P: Model + 'static,
2711 Child: Model + 'static,
2712 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2713 FP: Fn(&P) -> Value,
2714 {
2715 self.flush_related_many_pk(cx, objects, accessor, |p| vec![parent_pk(p)], link_table)
2716 .await
2717 }
2718
2719 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
2721 pub async fn flush_related_many_pk<P, Child, FA, FP>(
2722 &mut self,
2723 cx: &Cx,
2724 objects: &mut [P],
2725 accessor: FA,
2726 parent_pk: FP,
2727 link_table: &sqlmodel_core::LinkTableInfo,
2728 ) -> Outcome<usize, Error>
2729 where
2730 P: Model + 'static,
2731 Child: Model + 'static,
2732 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2733 FP: Fn(&P) -> Vec<Value>,
2734 {
2735 let mut ops = Vec::new();
2736 let local_cols = link_table.local_cols();
2737 let remote_cols = link_table.remote_cols();
2738 if local_cols.is_empty() || remote_cols.is_empty() {
2739 return Outcome::Err(Error::Custom(
2740 "link_table must specify local/remote columns".to_string(),
2741 ));
2742 }
2743
2744 for obj in objects.iter_mut() {
2746 let parent_pk_values = parent_pk(obj);
2747 if parent_pk_values.len() != local_cols.len() {
2748 return Outcome::Err(Error::Custom(format!(
2749 "parent_pk len ({}) must match link_table local cols len ({})",
2750 parent_pk_values.len(),
2751 local_cols.len()
2752 )));
2753 }
2754 let related = accessor(obj);
2755
2756 for child_pk_values in related.take_pending_links() {
2758 if child_pk_values.len() != remote_cols.len() {
2759 return Outcome::Err(Error::Custom(format!(
2760 "child pk len ({}) must match link_table remote cols len ({})",
2761 child_pk_values.len(),
2762 remote_cols.len()
2763 )));
2764 }
2765 ops.push(LinkTableOp::link_multi(
2766 link_table.table_name.to_string(),
2767 local_cols.iter().map(|c| (*c).to_string()).collect(),
2768 parent_pk_values.clone(),
2769 remote_cols.iter().map(|c| (*c).to_string()).collect(),
2770 child_pk_values,
2771 ));
2772 }
2773
2774 for child_pk_values in related.take_pending_unlinks() {
2776 if child_pk_values.len() != remote_cols.len() {
2777 return Outcome::Err(Error::Custom(format!(
2778 "child pk len ({}) must match link_table remote cols len ({})",
2779 child_pk_values.len(),
2780 remote_cols.len()
2781 )));
2782 }
2783 ops.push(LinkTableOp::unlink_multi(
2784 link_table.table_name.to_string(),
2785 local_cols.iter().map(|c| (*c).to_string()).collect(),
2786 parent_pk_values.clone(),
2787 remote_cols.iter().map(|c| (*c).to_string()).collect(),
2788 child_pk_values,
2789 ));
2790 }
2791 }
2792
2793 if ops.is_empty() {
2794 return Outcome::Ok(0);
2795 }
2796
2797 tracing::info!(
2798 parent_model = std::any::type_name::<P>(),
2799 related_model = std::any::type_name::<Child>(),
2800 link_count = ops
2801 .iter()
2802 .filter(|o| matches!(o, LinkTableOp::Link { .. }))
2803 .count(),
2804 unlink_count = ops
2805 .iter()
2806 .filter(|o| matches!(o, LinkTableOp::Unlink { .. }))
2807 .count(),
2808 link_table = link_table.table_name,
2809 "Flushing many-to-many relationship changes"
2810 );
2811
2812 flush::execute_link_table_ops(cx, &self.connection, &ops).await
2813 }
2814
2815 pub fn relate_to_one<Child, Parent, FC, FP, FK>(
2840 &self,
2841 child: &mut Child,
2842 child_accessor: FC,
2843 set_fk: FK,
2844 parent: &mut Parent,
2845 parent_accessor: FP,
2846 ) where
2847 Child: Model + Clone + 'static,
2848 Parent: Model + Clone + 'static,
2849 FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2850 FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2851 FK: FnOnce(&mut Child),
2852 {
2853 let related = child_accessor(child);
2855 let _ = related.set_loaded(Some(parent.clone()));
2856
2857 set_fk(child);
2859
2860 let related_many = parent_accessor(parent);
2862 related_many.link(child);
2863
2864 tracing::debug!(
2865 child_model = std::any::type_name::<Child>(),
2866 parent_model = std::any::type_name::<Parent>(),
2867 "Established bidirectional ManyToOne <-> OneToMany relationship"
2868 );
2869 }
2870
2871 pub fn unrelate_from_one<Child, Parent, FC, FP, FK>(
2887 &self,
2888 child: &mut Child,
2889 child_accessor: FC,
2890 clear_fk: FK,
2891 parent: &mut Parent,
2892 parent_accessor: FP,
2893 ) where
2894 Child: Model + Clone + 'static,
2895 Parent: Model + Clone + 'static,
2896 FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2897 FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2898 FK: FnOnce(&mut Child),
2899 {
2900 let related = child_accessor(child);
2902 *related = sqlmodel_core::Related::empty();
2903
2904 clear_fk(child);
2906
2907 let related_many = parent_accessor(parent);
2909 related_many.unlink(child);
2910
2911 tracing::debug!(
2912 child_model = std::any::type_name::<Child>(),
2913 parent_model = std::any::type_name::<Parent>(),
2914 "Removed bidirectional ManyToOne <-> OneToMany relationship"
2915 );
2916 }
2917
2918 pub fn relate_many_to_many<Left, Right, FL, FR>(
2937 &self,
2938 left: &mut Left,
2939 left_accessor: FL,
2940 right: &mut Right,
2941 right_accessor: FR,
2942 ) where
2943 Left: Model + Clone + 'static,
2944 Right: Model + Clone + 'static,
2945 FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2946 FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2947 {
2948 let left_coll = left_accessor(left);
2950 left_coll.link(right);
2951
2952 let right_coll = right_accessor(right);
2954 right_coll.link(left);
2955
2956 tracing::debug!(
2957 left_model = std::any::type_name::<Left>(),
2958 right_model = std::any::type_name::<Right>(),
2959 "Established bidirectional ManyToMany relationship"
2960 );
2961 }
2962
2963 pub fn unrelate_many_to_many<Left, Right, FL, FR>(
2967 &self,
2968 left: &mut Left,
2969 left_accessor: FL,
2970 right: &mut Right,
2971 right_accessor: FR,
2972 ) where
2973 Left: Model + Clone + 'static,
2974 Right: Model + Clone + 'static,
2975 FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2976 FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2977 {
2978 let left_coll = left_accessor(left);
2980 left_coll.unlink(right);
2981
2982 let right_coll = right_accessor(right);
2984 right_coll.unlink(left);
2985
2986 tracing::debug!(
2987 left_model = std::any::type_name::<Left>(),
2988 right_model = std::any::type_name::<Right>(),
2989 "Removed bidirectional ManyToMany relationship"
2990 );
2991 }
2992
2993 pub fn enable_n1_detection(&mut self, threshold: usize) {
3018 self.n1_tracker = Some(N1QueryTracker::new().with_threshold(threshold));
3019 }
3020
3021 pub fn disable_n1_detection(&mut self) {
3023 self.n1_tracker = None;
3024 }
3025
3026 #[must_use]
3028 pub fn n1_detection_enabled(&self) -> bool {
3029 self.n1_tracker.is_some()
3030 }
3031
3032 pub fn n1_tracker_mut(&mut self) -> Option<&mut N1QueryTracker> {
3034 self.n1_tracker.as_mut()
3035 }
3036
3037 #[must_use]
3039 pub fn n1_stats(&self) -> Option<N1Stats> {
3040 self.n1_tracker.as_ref().map(|t| t.stats())
3041 }
3042
3043 pub fn reset_n1_tracking(&mut self) {
3045 if let Some(tracker) = &mut self.n1_tracker {
3046 tracker.reset();
3047 }
3048 }
3049
3050 #[track_caller]
3054 pub fn record_lazy_load(&mut self, parent_type: &'static str, relationship: &'static str) {
3055 if let Some(tracker) = &mut self.n1_tracker {
3056 tracker.record_load(parent_type, relationship);
3057 }
3058 }
3059
3060 #[tracing::instrument(level = "debug", skip(self, cx, model), fields(table = M::TABLE_NAME))]
3104 pub async fn merge<
3105 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
3106 >(
3107 &mut self,
3108 cx: &Cx,
3109 model: M,
3110 load: bool,
3111 ) -> Outcome<M, Error> {
3112 let pk_values = model.primary_key_value();
3113 let key = ObjectKey::from_model(&model);
3114
3115 tracing::debug!(
3116 pk = ?pk_values,
3117 load = load,
3118 in_identity_map = self.identity_map.contains_key(&key),
3119 "Merging object"
3120 );
3121
3122 if let Some(tracked) = self.identity_map.get_mut(&key) {
3124 if tracked.state == ObjectState::Detached {
3126 tracing::debug!("Found detached object, treating as new");
3127 } else {
3128 tracing::debug!(
3129 state = ?tracked.state,
3130 "Found tracked object, updating with merged values"
3131 );
3132
3133 let row_data = model.to_row();
3135 tracked.object = Box::new(model.clone());
3136 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
3137 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
3138 tracked.pk_values.clone_from(&pk_values);
3139
3140 if tracked.state == ObjectState::Persistent && !self.pending_dirty.contains(&key) {
3142 self.pending_dirty.push(key);
3143 }
3144
3145 if let Some(obj) = tracked.object.downcast_ref::<M>() {
3147 return Outcome::Ok(obj.clone());
3148 }
3149 }
3150 }
3151
3152 if load {
3154 let has_valid_pk = pk_values
3156 .iter()
3157 .all(|v| !matches!(v, Value::Null | Value::Default));
3158
3159 if has_valid_pk {
3160 tracing::debug!("Loading from database");
3161
3162 let db_result = self.get_by_pk::<M>(cx, &pk_values).await;
3163 match db_result {
3164 Outcome::Ok(Some(_existing)) => {
3165 if let Some(tracked) = self.identity_map.get_mut(&key) {
3168 let row_data = model.to_row();
3169 tracked.object = Box::new(model.clone());
3170 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
3171 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
3172 if !self.pending_dirty.contains(&key) {
3176 self.pending_dirty.push(key);
3177 }
3178
3179 tracing::debug!("Merged values onto DB object");
3180
3181 if let Some(obj) = tracked.object.downcast_ref::<M>() {
3182 return Outcome::Ok(obj.clone());
3183 }
3184 }
3185 }
3186 Outcome::Ok(None) => {
3187 tracing::debug!("Object not found in database, treating as new");
3188 }
3189 Outcome::Err(e) => return Outcome::Err(e),
3190 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
3191 Outcome::Panicked(p) => return Outcome::Panicked(p),
3192 }
3193 }
3194 }
3195
3196 tracing::debug!("Adding as new object");
3198 self.add(&model);
3199
3200 Outcome::Ok(model)
3201 }
3202
3203 pub async fn merge_without_load<
3215 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
3216 >(
3217 &mut self,
3218 cx: &Cx,
3219 model: M,
3220 ) -> Outcome<M, Error> {
3221 self.merge(cx, model, false).await
3222 }
3223
3224 pub fn pending_new_count(&self) -> usize {
3230 self.pending_new.len()
3231 }
3232
3233 pub fn pending_delete_count(&self) -> usize {
3235 self.pending_delete.len()
3236 }
3237
3238 pub fn pending_dirty_count(&self) -> usize {
3240 self.pending_dirty.len()
3241 }
3242
3243 pub fn tracked_count(&self) -> usize {
3245 self.identity_map.len()
3246 }
3247
3248 pub fn in_transaction(&self) -> bool {
3250 self.in_transaction
3251 }
3252
3253 pub fn debug_state(&self) -> SessionDebugInfo {
3255 SessionDebugInfo {
3256 tracked: self.tracked_count(),
3257 pending_new: self.pending_new_count(),
3258 pending_delete: self.pending_delete_count(),
3259 pending_dirty: self.pending_dirty_count(),
3260 in_transaction: self.in_transaction,
3261 }
3262 }
3263
3264 pub async fn bulk_insert<M: Model + Clone + Send + Sync + 'static>(
3278 &mut self,
3279 cx: &Cx,
3280 models: &[M],
3281 ) -> Outcome<u64, Error> {
3282 self.bulk_insert_with_batch_size(cx, models, 1000).await
3283 }
3284
3285 pub async fn bulk_insert_with_batch_size<M: Model + Clone + Send + Sync + 'static>(
3287 &mut self,
3288 cx: &Cx,
3289 models: &[M],
3290 batch_size: usize,
3291 ) -> Outcome<u64, Error> {
3292 if models.is_empty() {
3293 return Outcome::Ok(0);
3294 }
3295
3296 let batch_size = batch_size.max(1);
3297 let mut total_inserted: u64 = 0;
3298
3299 for chunk in models.chunks(batch_size) {
3300 let builder = sqlmodel_query::InsertManyBuilder::new(chunk);
3301 match builder.execute(cx, &self.connection).await {
3302 Outcome::Ok(count) => total_inserted += count,
3303 Outcome::Err(e) => return Outcome::Err(e),
3304 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
3305 Outcome::Panicked(p) => return Outcome::Panicked(p),
3306 }
3307 }
3308
3309 Outcome::Ok(total_inserted)
3310 }
3311
3312 pub async fn bulk_update<M: Model + Clone + Send + Sync + 'static>(
3320 &mut self,
3321 cx: &Cx,
3322 models: &[M],
3323 ) -> Outcome<u64, Error> {
3324 if models.is_empty() {
3325 return Outcome::Ok(0);
3326 }
3327
3328 let mut total_updated: u64 = 0;
3329
3330 for model in models {
3331 let builder = sqlmodel_query::UpdateBuilder::new(model);
3332 let (sql, params) = builder.build_with_dialect(self.connection.dialect());
3333
3334 if sql.is_empty() {
3335 continue;
3336 }
3337
3338 match self.connection.execute(cx, &sql, ¶ms).await {
3339 Outcome::Ok(count) => total_updated += count,
3340 Outcome::Err(e) => return Outcome::Err(e),
3341 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
3342 Outcome::Panicked(p) => return Outcome::Panicked(p),
3343 }
3344 }
3345
3346 Outcome::Ok(total_updated)
3347 }
3348}
3349
/// Lets a `Session` act as a `LazyLoader<M>` by forwarding to the session's own
/// inherent `get`. Presumably this is what resolves `Lazy<M>` fields; confirm
/// against `sqlmodel_core::LazyLoader`.
impl<C, M> LazyLoader<M> for Session<C>
where
    C: Connection,
    M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Look up the `M` whose primary key is `pk`, delegating to the inherent
    /// `Session::get`.
    fn get(
        &mut self,
        cx: &Cx,
        pk: Value,
    ) -> impl Future<Output = Outcome<Option<M>, Error>> + Send {
        // The `Session::get` path resolves to the inherent method, because inherent
        // associated items take priority over trait items with the same name. It
        // therefore does not recurse into this `LazyLoader::get`. Keep this explicit
        // form.
        Session::get(self, cx, pk)
    }
}
3363
/// Point-in-time snapshot of a session's bookkeeping counters, as produced by
/// `Session::debug_state`.
///
/// `Default` yields an all-zero, not-in-transaction snapshot. `PartialEq`/`Eq`
/// allow whole snapshots to be compared in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionDebugInfo {
    /// Number of entries in the identity map.
    pub tracked: usize,
    /// Number of objects in the pending-new list.
    pub pending_new: usize,
    /// Number of objects in the pending-delete list.
    pub pending_delete: usize,
    /// Number of identity-map keys queued as dirty.
    pub pending_dirty: usize,
    /// Whether the session's `in_transaction` flag was set.
    pub in_transaction: bool,
}
3378
3379#[cfg(test)]
3384#[allow(clippy::manual_async_fn)] mod tests {
3386 use super::*;
3387 use asupersync::runtime::RuntimeBuilder;
3388 use sqlmodel_core::Row;
3389 use std::sync::{Arc, Mutex};
3390
3391 #[test]
3392 fn test_session_config_defaults() {
3393 let config = SessionConfig::default();
3394 assert!(config.auto_begin);
3395 assert!(!config.auto_flush);
3396 assert!(config.expire_on_commit);
3397 }
3398
3399 #[test]
3400 fn test_object_key_hash_consistency() {
3401 let values1 = vec![Value::BigInt(42)];
3402 let values2 = vec![Value::BigInt(42)];
3403 let hash1 = hash_values(&values1);
3404 let hash2 = hash_values(&values2);
3405 assert_eq!(hash1, hash2);
3406 }
3407
3408 #[test]
3409 fn test_object_key_hash_different_values() {
3410 let values1 = vec![Value::BigInt(42)];
3411 let values2 = vec![Value::BigInt(43)];
3412 let hash1 = hash_values(&values1);
3413 let hash2 = hash_values(&values2);
3414 assert_ne!(hash1, hash2);
3415 }
3416
3417 #[test]
3418 fn test_object_key_hash_different_types() {
3419 let values1 = vec![Value::BigInt(42)];
3420 let values2 = vec![Value::Text("42".to_string())];
3421 let hash1 = hash_values(&values1);
3422 let hash2 = hash_values(&values2);
3423 assert_ne!(hash1, hash2);
3424 }
3425
3426 #[test]
3427 fn test_session_debug_info() {
3428 let info = SessionDebugInfo {
3429 tracked: 5,
3430 pending_new: 2,
3431 pending_delete: 1,
3432 pending_dirty: 0,
3433 in_transaction: true,
3434 };
3435 assert_eq!(info.tracked, 5);
3436 assert_eq!(info.pending_new, 2);
3437 assert!(info.in_transaction);
3438 }
3439
3440 fn unwrap_outcome<T: std::fmt::Debug>(outcome: Outcome<T, Error>) -> T {
3441 match outcome {
3442 Outcome::Ok(v) => v,
3443 other => std::panic::panic_any(format!("unexpected outcome: {other:?}")),
3444 }
3445 }
3446
3447 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3448 struct Team {
3449 id: Option<i64>,
3450 name: String,
3451 }
3452
3453 impl Model for Team {
3454 const TABLE_NAME: &'static str = "teams";
3455 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3456
3457 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3458 &[]
3459 }
3460
3461 fn to_row(&self) -> Vec<(&'static str, Value)> {
3462 vec![
3463 ("id", self.id.map_or(Value::Null, Value::BigInt)),
3464 ("name", Value::Text(self.name.clone())),
3465 ]
3466 }
3467
3468 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
3469 let id: i64 = row.get_named("id")?;
3470 let name: String = row.get_named("name")?;
3471 Ok(Self { id: Some(id), name })
3472 }
3473
3474 fn primary_key_value(&self) -> Vec<Value> {
3475 self.id
3476 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
3477 }
3478
3479 fn is_new(&self) -> bool {
3480 self.id.is_none()
3481 }
3482 }
3483
3484 #[derive(Debug, Clone, Serialize, Deserialize)]
3485 struct Hero {
3486 id: Option<i64>,
3487 team: Lazy<Team>,
3488 }
3489
3490 impl Model for Hero {
3491 const TABLE_NAME: &'static str = "heroes";
3492 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3493
3494 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3495 &[]
3496 }
3497
3498 fn to_row(&self) -> Vec<(&'static str, Value)> {
3499 vec![]
3500 }
3501
3502 fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
3503 Ok(Self {
3504 id: None,
3505 team: Lazy::empty(),
3506 })
3507 }
3508
3509 fn primary_key_value(&self) -> Vec<Value> {
3510 self.id
3511 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
3512 }
3513
3514 fn is_new(&self) -> bool {
3515 self.id.is_none()
3516 }
3517 }
3518
3519 #[derive(Debug, Default)]
3520 struct MockState {
3521 query_calls: usize,
3522 last_sql: Option<String>,
3523 execute_calls: usize,
3524 executed: Vec<(String, Vec<Value>)>,
3525 }
3526
3527 #[derive(Debug, Clone)]
3528 struct MockConnection {
3529 state: Arc<Mutex<MockState>>,
3530 dialect: sqlmodel_core::Dialect,
3531 }
3532
3533 impl MockConnection {
3534 fn new(state: Arc<Mutex<MockState>>) -> Self {
3535 Self {
3536 state,
3537 dialect: sqlmodel_core::Dialect::Postgres,
3538 }
3539 }
3540 }
3541
3542 impl sqlmodel_core::Connection for MockConnection {
3543 type Tx<'conn>
3544 = MockTransaction
3545 where
3546 Self: 'conn;
3547
3548 fn dialect(&self) -> sqlmodel_core::Dialect {
3549 self.dialect
3550 }
3551
3552 fn query(
3553 &self,
3554 _cx: &Cx,
3555 sql: &str,
3556 params: &[Value],
3557 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
3558 let params = params.to_vec();
3559 let state = Arc::clone(&self.state);
3560 let sql = sql.to_string();
3561 async move {
3562 {
3563 let mut guard = state.lock().expect("lock poisoned");
3564 guard.query_calls += 1;
3565 guard.last_sql = Some(sql.clone());
3566 }
3567
3568 let mut rows = Vec::new();
3569 let is_teams = sql.contains("teams");
3570 let is_heroes = sql.contains("heroes");
3571
3572 for v in params {
3573 if is_teams {
3574 match v {
3575 Value::BigInt(1) => rows.push(Row::new(
3576 vec!["id".into(), "name".into()],
3577 vec![Value::BigInt(1), Value::Text("Avengers".into())],
3578 )),
3579 Value::BigInt(2) => rows.push(Row::new(
3580 vec!["id".into(), "name".into()],
3581 vec![Value::BigInt(2), Value::Text("X-Men".into())],
3582 )),
3583 _ => {}
3584 }
3585 } else if is_heroes {
3586 match v {
3588 Value::BigInt(1) => {
3589 rows.push(Row::new(
3590 vec!["id".into(), "team_id".into(), "__parent_pk".into()],
3591 vec![Value::BigInt(101), Value::BigInt(1), Value::BigInt(1)],
3592 ));
3593 rows.push(Row::new(
3594 vec!["id".into(), "team_id".into(), "__parent_pk".into()],
3595 vec![Value::BigInt(102), Value::BigInt(1), Value::BigInt(1)],
3596 ));
3597 }
3598 Value::BigInt(2) => rows.push(Row::new(
3599 vec!["id".into(), "team_id".into(), "__parent_pk".into()],
3600 vec![Value::BigInt(201), Value::BigInt(2), Value::BigInt(2)],
3601 )),
3602 _ => {}
3603 }
3604 }
3605 }
3606
3607 Outcome::Ok(rows)
3608 }
3609 }
3610
3611 fn query_one(
3612 &self,
3613 _cx: &Cx,
3614 _sql: &str,
3615 _params: &[Value],
3616 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
3617 async { Outcome::Ok(None) }
3618 }
3619
3620 fn execute(
3621 &self,
3622 _cx: &Cx,
3623 sql: &str,
3624 params: &[Value],
3625 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
3626 let state = Arc::clone(&self.state);
3627 let sql = sql.to_string();
3628 let params = params.to_vec();
3629 async move {
3630 let mut guard = state.lock().expect("lock poisoned");
3631 guard.execute_calls += 1;
3632 guard.executed.push((sql, params));
3633 Outcome::Ok(0)
3634 }
3635 }
3636
3637 fn insert(
3638 &self,
3639 _cx: &Cx,
3640 _sql: &str,
3641 _params: &[Value],
3642 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
3643 async { Outcome::Ok(0) }
3644 }
3645
3646 fn batch(
3647 &self,
3648 _cx: &Cx,
3649 _statements: &[(String, Vec<Value>)],
3650 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
3651 async { Outcome::Ok(vec![]) }
3652 }
3653
3654 fn begin(&self, _cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
3655 async { Outcome::Ok(MockTransaction) }
3656 }
3657
3658 fn begin_with(
3659 &self,
3660 _cx: &Cx,
3661 _isolation: sqlmodel_core::connection::IsolationLevel,
3662 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
3663 async { Outcome::Ok(MockTransaction) }
3664 }
3665
3666 fn prepare(
3667 &self,
3668 _cx: &Cx,
3669 _sql: &str,
3670 ) -> impl Future<Output = Outcome<sqlmodel_core::connection::PreparedStatement, Error>> + Send
3671 {
3672 async {
3673 Outcome::Ok(sqlmodel_core::connection::PreparedStatement::new(
3674 0,
3675 String::new(),
3676 0,
3677 ))
3678 }
3679 }
3680
3681 fn query_prepared(
3682 &self,
3683 _cx: &Cx,
3684 _stmt: &sqlmodel_core::connection::PreparedStatement,
3685 _params: &[Value],
3686 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
3687 async { Outcome::Ok(vec![]) }
3688 }
3689
3690 fn execute_prepared(
3691 &self,
3692 _cx: &Cx,
3693 _stmt: &sqlmodel_core::connection::PreparedStatement,
3694 _params: &[Value],
3695 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
3696 async { Outcome::Ok(0) }
3697 }
3698
3699 fn ping(&self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
3700 async { Outcome::Ok(()) }
3701 }
3702
3703 fn close(self, _cx: &Cx) -> impl Future<Output = sqlmodel_core::Result<()>> + Send {
3704 async { Ok(()) }
3705 }
3706 }
3707
3708 struct MockTransaction;
3709
3710 impl sqlmodel_core::connection::TransactionOps for MockTransaction {
3711 fn query(
3712 &self,
3713 _cx: &Cx,
3714 _sql: &str,
3715 _params: &[Value],
3716 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
3717 async { Outcome::Ok(vec![]) }
3718 }
3719
3720 fn query_one(
3721 &self,
3722 _cx: &Cx,
3723 _sql: &str,
3724 _params: &[Value],
3725 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
3726 async { Outcome::Ok(None) }
3727 }
3728
3729 fn execute(
3730 &self,
3731 _cx: &Cx,
3732 _sql: &str,
3733 _params: &[Value],
3734 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
3735 async { Outcome::Ok(0) }
3736 }
3737
3738 fn savepoint(
3739 &self,
3740 _cx: &Cx,
3741 _name: &str,
3742 ) -> impl Future<Output = Outcome<(), Error>> + Send {
3743 async { Outcome::Ok(()) }
3744 }
3745
3746 fn rollback_to(
3747 &self,
3748 _cx: &Cx,
3749 _name: &str,
3750 ) -> impl Future<Output = Outcome<(), Error>> + Send {
3751 async { Outcome::Ok(()) }
3752 }
3753
3754 fn release(
3755 &self,
3756 _cx: &Cx,
3757 _name: &str,
3758 ) -> impl Future<Output = Outcome<(), Error>> + Send {
3759 async { Outcome::Ok(()) }
3760 }
3761
3762 fn commit(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
3763 async { Outcome::Ok(()) }
3764 }
3765
3766 fn rollback(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
3767 async { Outcome::Ok(()) }
3768 }
3769 }
3770
3771 #[test]
3772 fn test_load_many_single_query_and_populates_lazy() {
3773 let rt = RuntimeBuilder::current_thread()
3774 .build()
3775 .expect("create asupersync runtime");
3776 let cx = Cx::for_testing();
3777
3778 let state = Arc::new(Mutex::new(MockState::default()));
3779 let conn = MockConnection::new(Arc::clone(&state));
3780 let mut session = Session::new(conn);
3781
3782 let heroes = vec![
3783 Hero {
3784 id: Some(1),
3785 team: Lazy::from_fk(1_i64),
3786 },
3787 Hero {
3788 id: Some(2),
3789 team: Lazy::from_fk(2_i64),
3790 },
3791 Hero {
3792 id: Some(3),
3793 team: Lazy::from_fk(1_i64),
3794 },
3795 Hero {
3796 id: Some(4),
3797 team: Lazy::empty(),
3798 },
3799 Hero {
3800 id: Some(5),
3801 team: Lazy::from_fk(999_i64),
3802 },
3803 ];
3804
3805 rt.block_on(async {
3806 let loaded = unwrap_outcome(
3807 session
3808 .load_many::<Hero, Team, _>(&cx, &heroes, |h| &h.team)
3809 .await,
3810 );
3811 assert_eq!(loaded, 3);
3812
3813 assert!(heroes[0].team.is_loaded());
3815 assert_eq!(heroes[0].team.get().unwrap().name, "Avengers");
3816 assert_eq!(heroes[1].team.get().unwrap().name, "X-Men");
3817 assert_eq!(heroes[2].team.get().unwrap().name, "Avengers");
3818
3819 assert!(heroes[3].team.is_loaded());
3821 assert!(heroes[3].team.get().is_none());
3822
3823 assert!(heroes[4].team.is_loaded());
3825 assert!(heroes[4].team.get().is_none());
3826
3827 let team1 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await);
3829 assert_eq!(
3830 team1,
3831 Some(Team {
3832 id: Some(1),
3833 name: "Avengers".to_string()
3834 })
3835 );
3836 });
3837
3838 assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3839 }
3840
3841 #[derive(Debug, Clone, Serialize, Deserialize)]
3842 struct HeroChild {
3843 id: Option<i64>,
3844 team_id: i64,
3845 }
3846
3847 impl Model for HeroChild {
3848 const TABLE_NAME: &'static str = "heroes";
3849 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3850
3851 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3852 &[]
3853 }
3854
3855 fn to_row(&self) -> Vec<(&'static str, Value)> {
3856 vec![
3857 ("id", self.id.map_or(Value::Null, Value::BigInt)),
3858 ("team_id", Value::BigInt(self.team_id)),
3859 ]
3860 }
3861
3862 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
3863 let id: i64 = row.get_named("id")?;
3864 let team_id: i64 = row.get_named("team_id")?;
3865 Ok(Self {
3866 id: Some(id),
3867 team_id,
3868 })
3869 }
3870
3871 fn primary_key_value(&self) -> Vec<Value> {
3872 self.id
3873 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
3874 }
3875
3876 fn is_new(&self) -> bool {
3877 self.id.is_none()
3878 }
3879 }
3880
3881 #[derive(Debug, Clone, Serialize, Deserialize)]
3882 struct TeamWithHeroes {
3883 id: Option<i64>,
3884 heroes: sqlmodel_core::RelatedMany<HeroChild>,
3885 }
3886
3887 impl Model for TeamWithHeroes {
3888 const TABLE_NAME: &'static str = "teams";
3889 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3890 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] =
3891 &[sqlmodel_core::RelationshipInfo::new(
3892 "heroes",
3893 "heroes",
3894 sqlmodel_core::RelationshipKind::OneToMany,
3895 )
3896 .remote_key("team_id")
3897 .cascade_delete(true)];
3898
3899 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3900 &[]
3901 }
3902
3903 fn to_row(&self) -> Vec<(&'static str, Value)> {
3904 vec![("id", self.id.map_or(Value::Null, Value::BigInt))]
3905 }
3906
3907 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
3908 let id: i64 = row.get_named("id")?;
3909 Ok(Self {
3910 id: Some(id),
3911 heroes: sqlmodel_core::RelatedMany::new("team_id"),
3912 })
3913 }
3914
3915 fn primary_key_value(&self) -> Vec<Value> {
3916 self.id
3917 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
3918 }
3919
3920 fn is_new(&self) -> bool {
3921 self.id.is_none()
3922 }
3923 }
3924
3925 #[derive(Debug, Clone, Serialize, Deserialize)]
3926 struct TeamWithHeroesPassive {
3927 id: Option<i64>,
3928 heroes: sqlmodel_core::RelatedMany<HeroChild>,
3929 }
3930
3931 impl Model for TeamWithHeroesPassive {
3932 const TABLE_NAME: &'static str = "teams_passive";
3933 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3934 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] =
3935 &[sqlmodel_core::RelationshipInfo::new(
3936 "heroes",
3937 "heroes",
3938 sqlmodel_core::RelationshipKind::OneToMany,
3939 )
3940 .remote_key("team_id")
3941 .cascade_delete(true)
3942 .passive_deletes(sqlmodel_core::PassiveDeletes::Passive)];
3943
3944 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3945 &[]
3946 }
3947
3948 fn to_row(&self) -> Vec<(&'static str, Value)> {
3949 vec![("id", self.id.map_or(Value::Null, Value::BigInt))]
3950 }
3951
3952 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
3953 let id: i64 = row.get_named("id")?;
3954 Ok(Self {
3955 id: Some(id),
3956 heroes: sqlmodel_core::RelatedMany::new("team_id"),
3957 })
3958 }
3959
3960 fn primary_key_value(&self) -> Vec<Value> {
3961 self.id
3962 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
3963 }
3964
3965 fn is_new(&self) -> bool {
3966 self.id.is_none()
3967 }
3968 }
3969
3970 #[derive(Debug, Clone, Serialize, Deserialize)]
3971 struct HeroCompositeChild {
3972 id: Option<i64>,
3973 team_id1: i64,
3974 team_id2: i64,
3975 }
3976
3977 impl Model for HeroCompositeChild {
3978 const TABLE_NAME: &'static str = "heroes_composite";
3979 const PRIMARY_KEY: &'static [&'static str] = &["id"];
3980
3981 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
3982 &[]
3983 }
3984
3985 fn to_row(&self) -> Vec<(&'static str, Value)> {
3986 vec![
3987 ("id", self.id.map_or(Value::Null, Value::BigInt)),
3988 ("team_id1", Value::BigInt(self.team_id1)),
3989 ("team_id2", Value::BigInt(self.team_id2)),
3990 ]
3991 }
3992
3993 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
3994 let id: i64 = row.get_named("id")?;
3995 let team_id1: i64 = row.get_named("team_id1")?;
3996 let team_id2: i64 = row.get_named("team_id2")?;
3997 Ok(Self {
3998 id: Some(id),
3999 team_id1,
4000 team_id2,
4001 })
4002 }
4003
4004 fn primary_key_value(&self) -> Vec<Value> {
4005 self.id
4006 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
4007 }
4008
4009 fn is_new(&self) -> bool {
4010 self.id.is_none()
4011 }
4012 }
4013
4014 #[derive(Debug, Clone, Serialize, Deserialize)]
4015 struct TeamComposite {
4016 id1: Option<i64>,
4017 id2: Option<i64>,
4018 }
4019
4020 impl Model for TeamComposite {
4021 const TABLE_NAME: &'static str = "teams_composite";
4022 const PRIMARY_KEY: &'static [&'static str] = &["id1", "id2"];
4023 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] =
4024 &[sqlmodel_core::RelationshipInfo::new(
4025 "heroes",
4026 "heroes_composite",
4027 sqlmodel_core::RelationshipKind::OneToMany,
4028 )
4029 .remote_keys(&["team_id1", "team_id2"])
4030 .cascade_delete(true)];
4031
4032 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
4033 &[]
4034 }
4035
4036 fn to_row(&self) -> Vec<(&'static str, Value)> {
4037 vec![
4038 ("id1", self.id1.map_or(Value::Null, Value::BigInt)),
4039 ("id2", self.id2.map_or(Value::Null, Value::BigInt)),
4040 ]
4041 }
4042
4043 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
4044 let id1: i64 = row.get_named("id1")?;
4045 let id2: i64 = row.get_named("id2")?;
4046 Ok(Self {
4047 id1: Some(id1),
4048 id2: Some(id2),
4049 })
4050 }
4051
4052 fn primary_key_value(&self) -> Vec<Value> {
4053 match (self.id1, self.id2) {
4054 (Some(a), Some(b)) => vec![Value::BigInt(a), Value::BigInt(b)],
4055 _ => vec![Value::Null, Value::Null],
4056 }
4057 }
4058
4059 fn is_new(&self) -> bool {
4060 self.id1.is_none() || self.id2.is_none()
4061 }
4062 }
4063
4064 #[derive(Debug, Clone, Serialize, Deserialize)]
4065 struct TeamCompositePassive {
4066 id1: Option<i64>,
4067 id2: Option<i64>,
4068 }
4069
4070 impl Model for TeamCompositePassive {
4071 const TABLE_NAME: &'static str = "teams_composite_passive";
4072 const PRIMARY_KEY: &'static [&'static str] = &["id1", "id2"];
4073 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] =
4074 &[sqlmodel_core::RelationshipInfo::new(
4075 "heroes",
4076 "heroes_composite",
4077 sqlmodel_core::RelationshipKind::OneToMany,
4078 )
4079 .remote_keys(&["team_id1", "team_id2"])
4080 .cascade_delete(true)
4081 .passive_deletes(sqlmodel_core::PassiveDeletes::Passive)];
4082
4083 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
4084 &[]
4085 }
4086
4087 fn to_row(&self) -> Vec<(&'static str, Value)> {
4088 vec![
4089 ("id1", self.id1.map_or(Value::Null, Value::BigInt)),
4090 ("id2", self.id2.map_or(Value::Null, Value::BigInt)),
4091 ]
4092 }
4093
4094 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
4095 let id1: i64 = row.get_named("id1")?;
4096 let id2: i64 = row.get_named("id2")?;
4097 Ok(Self {
4098 id1: Some(id1),
4099 id2: Some(id2),
4100 })
4101 }
4102
4103 fn primary_key_value(&self) -> Vec<Value> {
4104 match (self.id1, self.id2) {
4105 (Some(a), Some(b)) => vec![Value::BigInt(a), Value::BigInt(b)],
4106 _ => vec![Value::Null, Value::Null],
4107 }
4108 }
4109
4110 fn is_new(&self) -> bool {
4111 self.id1.is_none() || self.id2.is_none()
4112 }
4113 }
4114
4115 #[test]
4116 fn test_load_one_to_many_single_query_and_populates_related_many() {
4117 let rt = RuntimeBuilder::current_thread()
4118 .build()
4119 .expect("create asupersync runtime");
4120 let cx = Cx::for_testing();
4121
4122 let state = Arc::new(Mutex::new(MockState::default()));
4123 let conn = MockConnection::new(Arc::clone(&state));
4124 let mut session = Session::new(conn);
4125
4126 let mut teams = vec![
4127 TeamWithHeroes {
4128 id: Some(1),
4129 heroes: sqlmodel_core::RelatedMany::new("team_id"),
4130 },
4131 TeamWithHeroes {
4132 id: Some(2),
4133 heroes: sqlmodel_core::RelatedMany::new("team_id"),
4134 },
4135 TeamWithHeroes {
4136 id: None,
4137 heroes: sqlmodel_core::RelatedMany::new("team_id"),
4138 },
4139 ];
4140
4141 rt.block_on(async {
4142 let loaded = unwrap_outcome(
4143 session
4144 .load_one_to_many::<TeamWithHeroes, HeroChild, _, _>(
4145 &cx,
4146 &mut teams,
4147 |t| &mut t.heroes,
4148 |t| t.id.map_or(Value::Null, Value::BigInt),
4149 )
4150 .await,
4151 );
4152 assert_eq!(loaded, 3);
4153
4154 assert!(teams[0].heroes.is_loaded());
4155 assert_eq!(teams[0].heroes.len(), 2);
4156 assert_eq!(teams[0].heroes.parent_pk(), Some(&Value::BigInt(1)));
4157
4158 assert!(teams[1].heroes.is_loaded());
4159 assert_eq!(teams[1].heroes.len(), 1);
4160 assert_eq!(teams[1].heroes.parent_pk(), Some(&Value::BigInt(2)));
4161
4162 assert!(teams[2].heroes.is_loaded());
4164 assert_eq!(teams[2].heroes.len(), 0);
4165 assert_eq!(teams[2].heroes.parent_pk(), Some(&Value::Null));
4166 });
4167
4168 assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
4169 let sql = state
4170 .lock()
4171 .expect("lock poisoned")
4172 .last_sql
4173 .clone()
4174 .expect("sql captured");
4175 assert!(sql.contains("FROM"), "expected SQL to contain FROM");
4176 assert!(
4177 sql.contains("heroes"),
4178 "expected SQL to target heroes table"
4179 );
4180 assert!(
4181 sql.contains("$1"),
4182 "expected Postgres-style placeholders ($1, $2, ...)"
4183 );
4184 assert!(
4185 sql.contains("$2"),
4186 "expected Postgres-style placeholders ($1, $2, ...)"
4187 );
4188 }
4189
4190 #[test]
4191 fn test_flush_cascade_delete_one_to_many_deletes_children_first() {
4192 let rt = RuntimeBuilder::current_thread()
4193 .build()
4194 .expect("create asupersync runtime");
4195 let cx = Cx::for_testing();
4196
4197 let state = Arc::new(Mutex::new(MockState::default()));
4198 let conn = MockConnection::new(Arc::clone(&state));
4199 let mut session = Session::with_config(
4200 conn,
4201 SessionConfig {
4202 auto_begin: false,
4203 auto_flush: false,
4204 expire_on_commit: true,
4205 },
4206 );
4207
4208 rt.block_on(async {
4209 let team = unwrap_outcome(session.get::<TeamWithHeroes>(&cx, 1_i64).await).unwrap();
4211
4212 let mut teams = vec![team.clone()];
4214 let loaded = unwrap_outcome(
4215 session
4216 .load_one_to_many::<TeamWithHeroes, HeroChild, _, _>(
4217 &cx,
4218 &mut teams,
4219 |t| &mut t.heroes,
4220 |t| t.id.map_or(Value::Null, Value::BigInt),
4221 )
4222 .await,
4223 );
4224 assert_eq!(loaded, 2);
4225
4226 session.delete(&team);
4228 unwrap_outcome(session.flush(&cx).await);
4229
4230 assert_eq!(session.tracked_count(), 0);
4232 });
4233
4234 let guard = state.lock().expect("lock poisoned");
4235 assert!(
4236 guard.execute_calls >= 2,
4237 "expected at least cascade + parent delete"
4238 );
4239 let (sql0, _params0) = &guard.executed[0];
4240 let (sql1, _params1) = &guard.executed[1];
4241 assert!(
4242 sql0.contains("DELETE") && sql0.contains("heroes"),
4243 "expected first delete to target child table"
4244 );
4245 assert!(
4246 sql1.contains("DELETE") && sql1.contains("teams"),
4247 "expected second delete to target parent table"
4248 );
4249 }
4250
/// Deleting a parent whose relationship uses `passive_deletes` must issue only
/// the parent DELETE (no statement against the child table), while still
/// leaving nothing tracked in the session after the flush.
#[test]
fn test_flush_passive_deletes_does_not_emit_child_delete_but_detaches_children() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let config = SessionConfig {
        auto_begin: false,
        auto_flush: false,
        expire_on_commit: true,
    };
    let mut session = Session::with_config(MockConnection::new(Arc::clone(&state)), config);

    rt.block_on(async {
        let fetched = session.get::<TeamWithHeroesPassive>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();

        // Load the children through a one-element copy of the parent so the
        // original `team` can still be passed to `delete` afterwards.
        let mut parents = vec![team.clone()];
        let load = session
            .load_one_to_many::<TeamWithHeroesPassive, HeroChild, _, _>(
                &cx,
                &mut parents,
                |parent| &mut parent.heroes,
                |parent| parent.id.map_or(Value::Null, Value::BigInt),
            )
            .await;
        let child_count = unwrap_outcome(load);
        assert_eq!(child_count, 2);

        session.delete(&team);
        unwrap_outcome(session.flush(&cx).await);

        // Nothing may remain tracked once the flush has completed.
        assert_eq!(session.tracked_count(), 0);
    });

    let guard = state.lock().expect("lock poisoned");
    assert_eq!(guard.execute_calls, 1, "expected only the parent delete");
    let (sql0, _) = &guard.executed[0];
    assert!(
        sql0.contains("teams_passive"),
        "expected delete to target parent table"
    );
    assert!(
        !sql0.contains("heroes"),
        "did not expect a child-table delete for passive_deletes"
    );
}
4305
/// Deleting a composite-key parent whose relationship cascades must first
/// delete the matching children, filtering on both `team_id1` and `team_id2`
/// with Postgres-style placeholders, before the parent row is deleted.
#[test]
fn test_flush_cascade_delete_composite_keys_deletes_children_first() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let config = SessionConfig {
        auto_begin: false,
        auto_flush: false,
        expire_on_commit: true,
    };
    let mut session = Session::with_config(MockConnection::new(Arc::clone(&state)), config);

    // Seed the identity map directly instead of loading through the mock.
    let team = TeamComposite {
        id1: Some(1),
        id2: Some(2),
    };
    session.identity_map.insert(
        ObjectKey::from_model(&team),
        TrackedObject {
            object: Box::new(team.clone()),
            original_state: None,
            state: ObjectState::Persistent,
            table_name: TeamComposite::TABLE_NAME,
            column_names: vec!["id1", "id2"],
            values: vec![Value::BigInt(1), Value::BigInt(2)],
            pk_columns: vec!["id1", "id2"],
            pk_values: vec![Value::BigInt(1), Value::BigInt(2)],
            relationships: TeamComposite::RELATIONSHIPS,
            expired_attributes: None,
        },
    );

    // Two persistent children whose (team_id1, team_id2) equals the parent key.
    let children = [10_i64, 11].map(|child_id| HeroCompositeChild {
        id: Some(child_id),
        team_id1: 1,
        team_id2: 2,
    });
    for child in children {
        let child_id = child.id.expect("child id");
        session.identity_map.insert(
            ObjectKey::from_model(&child),
            TrackedObject {
                object: Box::new(child),
                original_state: None,
                state: ObjectState::Persistent,
                table_name: HeroCompositeChild::TABLE_NAME,
                column_names: vec!["id", "team_id1", "team_id2"],
                values: vec![Value::BigInt(child_id), Value::BigInt(1), Value::BigInt(2)],
                pk_columns: vec!["id"],
                pk_values: vec![Value::BigInt(child_id)],
                relationships: HeroCompositeChild::RELATIONSHIPS,
                expired_attributes: None,
            },
        );
    }

    rt.block_on(async {
        session.delete(&team);
        unwrap_outcome(session.flush(&cx).await);
        assert_eq!(session.tracked_count(), 0);
    });

    let guard = state.lock().expect("lock poisoned");
    assert!(
        guard.execute_calls >= 2,
        "expected at least composite cascade + parent delete"
    );

    // The very first statement must be the cascade against the child table.
    let (sql0, _) = &guard.executed[0];
    assert!(sql0.contains("DELETE"), "expected DELETE SQL");
    assert!(
        sql0.contains("heroes_composite"),
        "expected composite cascade to target child table"
    );
    assert!(sql0.contains("team_id1"), "expected fk col team_id1");
    assert!(sql0.contains("team_id2"), "expected fk col team_id2");
    assert!(
        sql0.contains("$1") && sql0.contains("$2"),
        "expected Postgres-style placeholders for composite tuple"
    );
}
4402
/// With `passive_deletes` on a composite-key parent, flushing a delete issues
/// a single statement against the parent table and none against the child
/// table, yet the tracked child is still gone from the session afterwards.
#[test]
fn test_flush_passive_deletes_composite_keys_detaches_children_no_child_delete_sql() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let config = SessionConfig {
        auto_begin: false,
        auto_flush: false,
        expire_on_commit: true,
    };
    let mut session = Session::with_config(MockConnection::new(Arc::clone(&state)), config);

    // Seed a persistent composite-key parent.
    let team = TeamCompositePassive {
        id1: Some(1),
        id2: Some(2),
    };
    let parent_entry = TrackedObject {
        object: Box::new(team.clone()),
        original_state: None,
        state: ObjectState::Persistent,
        table_name: TeamCompositePassive::TABLE_NAME,
        column_names: vec!["id1", "id2"],
        values: vec![Value::BigInt(1), Value::BigInt(2)],
        pk_columns: vec!["id1", "id2"],
        pk_values: vec![Value::BigInt(1), Value::BigInt(2)],
        relationships: TeamCompositePassive::RELATIONSHIPS,
        expired_attributes: None,
    };
    session
        .identity_map
        .insert(ObjectKey::from_model(&team), parent_entry);

    // Seed one persistent child referencing the parent by (team_id1, team_id2).
    let child = HeroCompositeChild {
        id: Some(10),
        team_id1: 1,
        team_id2: 2,
    };
    let child_key = ObjectKey::from_model(&child);
    let child_entry = TrackedObject {
        object: Box::new(child),
        original_state: None,
        state: ObjectState::Persistent,
        table_name: HeroCompositeChild::TABLE_NAME,
        column_names: vec!["id", "team_id1", "team_id2"],
        values: vec![Value::BigInt(10), Value::BigInt(1), Value::BigInt(2)],
        pk_columns: vec!["id"],
        pk_values: vec![Value::BigInt(10)],
        relationships: HeroCompositeChild::RELATIONSHIPS,
        expired_attributes: None,
    };
    session.identity_map.insert(child_key, child_entry);

    rt.block_on(async {
        session.delete(&team);
        unwrap_outcome(session.flush(&cx).await);
        assert_eq!(session.tracked_count(), 0);
    });

    let guard = state.lock().expect("lock poisoned");
    assert_eq!(guard.execute_calls, 1, "expected only the parent delete");
    let (sql0, _) = &guard.executed[0];
    assert!(
        sql0.contains("teams_composite_passive"),
        "expected delete to target composite parent table"
    );
    assert!(
        !sql0.contains("heroes_composite"),
        "did not expect a child-table delete for passive_deletes"
    );
}
4482
4483 #[derive(Debug, Clone, Serialize, Deserialize)]
4484 struct MmChildComposite {
4485 id1: i64,
4486 id2: i64,
4487 }
4488
4489 impl Model for MmChildComposite {
4490 const TABLE_NAME: &'static str = "mm_children";
4491 const PRIMARY_KEY: &'static [&'static str] = &["id1", "id2"];
4492
4493 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
4494 &[]
4495 }
4496
4497 fn to_row(&self) -> Vec<(&'static str, Value)> {
4498 vec![
4499 ("id1", Value::BigInt(self.id1)),
4500 ("id2", Value::BigInt(self.id2)),
4501 ]
4502 }
4503
4504 fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
4505 Ok(Self { id1: 0, id2: 0 })
4506 }
4507
4508 fn primary_key_value(&self) -> Vec<Value> {
4509 vec![Value::BigInt(self.id1), Value::BigInt(self.id2)]
4510 }
4511
4512 fn is_new(&self) -> bool {
4513 false
4514 }
4515 }
4516
4517 #[derive(Debug, Clone, Serialize, Deserialize)]
4518 struct MmParentComposite {
4519 id1: i64,
4520 id2: i64,
4521 children: sqlmodel_core::RelatedMany<MmChildComposite>,
4522 }
4523
4524 impl Model for MmParentComposite {
4525 const TABLE_NAME: &'static str = "mm_parents";
4526 const PRIMARY_KEY: &'static [&'static str] = &["id1", "id2"];
4527 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] =
4528 &[sqlmodel_core::RelationshipInfo::new(
4529 "children",
4530 MmChildComposite::TABLE_NAME,
4531 sqlmodel_core::RelationshipKind::ManyToMany,
4532 )
4533 .link_table(sqlmodel_core::LinkTableInfo::composite(
4534 "mm_link",
4535 &["parent_id1", "parent_id2"],
4536 &["child_id1", "child_id2"],
4537 ))
4538 .cascade_delete(true)];
4539
4540 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
4541 &[]
4542 }
4543
4544 fn to_row(&self) -> Vec<(&'static str, Value)> {
4545 vec![
4546 ("id1", Value::BigInt(self.id1)),
4547 ("id2", Value::BigInt(self.id2)),
4548 ]
4549 }
4550
4551 fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
4552 Ok(Self {
4553 id1: 0,
4554 id2: 0,
4555 children: sqlmodel_core::RelatedMany::with_link_table(
4556 sqlmodel_core::LinkTableInfo::composite(
4557 "mm_link",
4558 &["parent_id1", "parent_id2"],
4559 &["child_id1", "child_id2"],
4560 ),
4561 ),
4562 })
4563 }
4564
4565 fn primary_key_value(&self) -> Vec<Value> {
4566 vec![Value::BigInt(self.id1), Value::BigInt(self.id2)]
4567 }
4568
4569 fn is_new(&self) -> bool {
4570 false
4571 }
4572 }
4573
/// Deleting a composite-key many-to-many parent whose relationship cascades
/// must first delete its rows from the link table, matching on both local
/// columns, and only then delete the parent row.
#[test]
fn test_flush_cascade_delete_many_to_many_composite_parent_keys_deletes_link_rows_first() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let config = SessionConfig {
        auto_begin: false,
        auto_flush: false,
        expire_on_commit: true,
    };
    let mut session = Session::with_config(MockConnection::new(Arc::clone(&state)), config);

    let link = sqlmodel_core::LinkTableInfo::composite(
        "mm_link",
        &["parent_id1", "parent_id2"],
        &["child_id1", "child_id2"],
    );
    let parent = MmParentComposite {
        id1: 1,
        id2: 2,
        children: sqlmodel_core::RelatedMany::with_link_table(link),
    };

    // Seed the parent as persistent so the delete needs no prior SELECT.
    session.identity_map.insert(
        ObjectKey::from_model(&parent),
        TrackedObject {
            object: Box::new(parent.clone()),
            original_state: None,
            state: ObjectState::Persistent,
            table_name: MmParentComposite::TABLE_NAME,
            column_names: vec!["id1", "id2"],
            values: vec![Value::BigInt(1), Value::BigInt(2)],
            pk_columns: vec!["id1", "id2"],
            pk_values: vec![Value::BigInt(1), Value::BigInt(2)],
            relationships: MmParentComposite::RELATIONSHIPS,
            expired_attributes: None,
        },
    );

    rt.block_on(async {
        session.delete(&parent);
        unwrap_outcome(session.flush(&cx).await);
        assert_eq!(session.tracked_count(), 0);
    });

    let guard = state.lock().expect("lock poisoned");
    assert!(
        guard.execute_calls >= 2,
        "expected at least link-table cascade + parent delete"
    );
    let (sql0, _) = &guard.executed[0];
    let (sql1, _) = &guard.executed[1];

    // Statement 1: link-table cascade keyed on both parent columns.
    assert!(
        sql0.contains("DELETE") && sql0.contains("mm_link"),
        "expected first delete to target link table"
    );
    assert!(
        sql0.contains("parent_id1") && sql0.contains("parent_id2"),
        "expected composite local cols in link delete"
    );

    // Statement 2: the parent row itself.
    assert!(
        sql1.contains("DELETE") && sql1.contains("mm_parents"),
        "expected second delete to target parent table"
    );
}
4647
/// Linking and then unlinking the same child before a flush is not collapsed:
/// `flush_related_many_pk` reports two operations and emits an INSERT into the
/// composite link table followed by a DELETE from it, each using placeholders
/// up to `$4`.
#[test]
fn test_flush_related_many_composite_link_and_unlink() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let config = SessionConfig {
        auto_begin: false,
        auto_flush: false,
        expire_on_commit: true,
    };
    let mut session = Session::with_config(MockConnection::new(Arc::clone(&state)), config);

    let link = sqlmodel_core::LinkTableInfo::composite(
        "mm_link",
        &["parent_id1", "parent_id2"],
        &["child_id1", "child_id2"],
    );
    let child = MmChildComposite { id1: 7, id2: 9 };

    let mut parents = vec![MmParentComposite {
        id1: 1,
        id2: 2,
        children: sqlmodel_core::RelatedMany::with_link_table(link),
    }];

    // Queue a link and then an unlink of the same child on the only parent.
    let pending = &mut parents[0].children;
    pending.link(&child);
    pending.unlink(&child);

    rt.block_on(async {
        let outcome = session
            .flush_related_many_pk::<MmParentComposite, MmChildComposite, _, _>(
                &cx,
                &mut parents,
                |parent| &mut parent.children,
                |parent| vec![Value::BigInt(parent.id1), Value::BigInt(parent.id2)],
                &link,
            )
            .await;
        let applied = unwrap_outcome(outcome);
        assert_eq!(applied, 2);
    });

    let guard = state.lock().expect("lock poisoned");
    assert_eq!(guard.execute_calls, 2);
    let (sql0, _) = &guard.executed[0];
    let (sql1, _) = &guard.executed[1];

    // First statement: the INSERT produced by `link`.
    assert!(sql0.contains("INSERT INTO"));
    assert!(sql0.contains("mm_link"));
    assert!(sql0.contains("parent_id1"));
    assert!(sql0.contains("parent_id2"));
    assert!(sql0.contains("child_id1"));
    assert!(sql0.contains("child_id2"));
    assert!(sql0.contains("$1") && sql0.contains("$4"));

    // Second statement: the DELETE produced by `unlink`.
    assert!(sql1.contains("DELETE FROM"));
    assert!(sql1.contains("mm_link"));
    assert!(sql1.contains("parent_id1"));
    assert!(sql1.contains("child_id2"));
    assert!(sql1.contains("$1") && sql1.contains("$4"));
}
4717
/// Loading a many-to-many relation for a composite-key parent issues exactly
/// one query that JOINs through `mm_link` and filters with a tuple
/// `IN ((...))` over both parent columns. Nothing is loaded from the mock.
#[test]
fn test_load_many_to_many_pk_composite_builds_tuple_where_clause() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    let link = sqlmodel_core::LinkTableInfo::composite(
        "mm_link",
        &["parent_id1", "parent_id2"],
        &["child_id1", "child_id2"],
    );
    let mut parents = vec![MmParentComposite {
        id1: 1,
        id2: 2,
        children: sqlmodel_core::RelatedMany::with_link_table(link),
    }];

    rt.block_on(async {
        let outcome = session
            .load_many_to_many_pk::<MmParentComposite, MmChildComposite, _, _>(
                &cx,
                &mut parents,
                |parent| &mut parent.children,
                |parent| vec![Value::BigInt(parent.id1), Value::BigInt(parent.id2)],
                &link,
            )
            .await;
        let child_count = unwrap_outcome(outcome);
        assert_eq!(child_count, 0);
    });

    let guard = state.lock().expect("lock poisoned");
    assert_eq!(guard.query_calls, 1);
    let sql = guard.last_sql.as_deref().expect("sql captured");
    assert!(sql.contains("JOIN"));
    assert!(sql.contains("mm_link"));
    assert!(sql.contains("WHERE"));
    assert!(sql.contains("parent_id1") && sql.contains("parent_id2"));
    assert!(sql.contains("IN (("), "expected tuple IN clause");
}
4765
/// `add_all` accepts a `&Vec<Team>` and registers every element as pending-new.
#[test]
fn test_add_all_with_vec() {
    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    let teams: Vec<Team> = [(100, "Team A"), (101, "Team B"), (102, "Team C")]
        .into_iter()
        .map(|(id, name)| Team {
            id: Some(id),
            name: name.to_string(),
        })
        .collect();

    session.add_all(&teams);

    let snapshot = session.debug_state();
    assert_eq!(snapshot.pending_new, 3);
    assert_eq!(snapshot.tracked, 3);
}
4795
/// `add_all` on an empty collection is a no-op: nothing becomes tracked.
#[test]
fn test_add_all_with_empty_collection() {
    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    let no_teams: Vec<Team> = Vec::new();
    session.add_all(&no_teams);

    let snapshot = session.debug_state();
    assert_eq!(snapshot.pending_new, 0);
    assert_eq!(snapshot.tracked, 0);
}
4809
/// `add_all` also accepts an explicit iterator of references (`.iter()`).
#[test]
fn test_add_all_with_iterator() {
    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    let teams = [(200, "Team X"), (201, "Team Y")].map(|(id, name)| Team {
        id: Some(id),
        name: name.to_string(),
    });

    session.add_all(teams.iter());

    let snapshot = session.debug_state();
    assert_eq!(snapshot.pending_new, 2);
    assert_eq!(snapshot.tracked, 2);
}
4834
/// `add_all` accepts a reference to a fixed-size array of models.
#[test]
fn test_add_all_with_slice() {
    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    let teams = [(300, "Team 1"), (301, "Team 2")].map(|(id, name)| Team {
        id: Some(id),
        name: name.to_string(),
    });

    session.add_all(&teams);

    let snapshot = session.debug_state();
    assert_eq!(snapshot.pending_new, 2);
    assert_eq!(snapshot.tracked, 2);
}
4858
/// Merging an untracked object with `load = false` must not touch the
/// database: the object is registered as pending-new and returned unchanged.
#[test]
fn test_merge_new_object_without_load() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let conn = MockConnection::new(Arc::clone(&state));
    let mut session = Session::new(conn);

    rt.block_on(async {
        let team = Team {
            id: Some(100),
            name: "New Team".to_string(),
        };

        // `team` is not used after the merge, so hand it over by value
        // instead of cloning it.
        let merged = unwrap_outcome(session.merge(&cx, team, false).await);

        assert_eq!(merged.id, Some(100));
        assert_eq!(merged.name, "New Team");

        let info = session.debug_state();
        assert_eq!(info.pending_new, 1);
        assert_eq!(info.tracked, 1);
    });

    // `load = false` means no SELECT may have been issued.
    assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
}
4894
/// Merging an object whose key is already tracked returns the incoming values
/// and keeps a single tracked entry for that key.
#[test]
fn test_merge_updates_existing_tracked_object() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let tracked = Team {
            id: Some(1),
            name: "Original".to_string(),
        };
        session.add(&tracked);

        let incoming = Team {
            id: Some(1),
            name: "Updated".to_string(),
        };
        let outcome = session.merge(&cx, incoming, false).await;
        let merged = unwrap_outcome(outcome);

        assert_eq!(merged.id, Some(1));
        assert_eq!(merged.name, "Updated");

        // Same key, so still exactly one tracked object.
        assert_eq!(session.debug_state().tracked, 1);
    });
}
4931
/// With `load = true`, merging a detached object issues exactly one query and
/// leaves the object tracked and pending as dirty, carrying the merged values.
#[test]
fn test_merge_with_load_queries_database() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let detached = Team {
            id: Some(1),
            name: "Detached Update".to_string(),
        };
        let outcome = session.merge(&cx, detached, true).await;
        let merged = unwrap_outcome(outcome);

        assert_eq!(merged.id, Some(1));
        assert_eq!(merged.name, "Detached Update");

        let snapshot = session.debug_state();
        assert_eq!(snapshot.tracked, 1);
        assert_eq!(snapshot.pending_dirty, 1);
    });

    let query_calls = state.lock().expect("lock poisoned").query_calls;
    assert_eq!(query_calls, 1);
}
4965
/// With `load = true`, when the load for the object's key finds nothing, the
/// merged object becomes pending-new; the lookup still costs one query.
#[test]
fn test_merge_with_load_not_found_creates_new() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let detached = Team {
            id: Some(999),
            name: "Not In DB".to_string(),
        };
        let outcome = session.merge(&cx, detached, true).await;
        let merged = unwrap_outcome(outcome);

        assert_eq!(merged.id, Some(999));
        assert_eq!(merged.name, "Not In DB");

        let snapshot = session.debug_state();
        assert_eq!(snapshot.pending_new, 1);
        assert_eq!(snapshot.tracked, 1);
    });

    let query_calls = state.lock().expect("lock poisoned").query_calls;
    assert_eq!(query_calls, 1);
}
4999
/// `merge_without_load` never queries the database; an untracked object is
/// returned as-is and becomes pending-new.
#[test]
fn test_merge_without_load_convenience() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let candidate = Team {
            id: Some(42),
            name: "Test".to_string(),
        };
        let outcome = session.merge_without_load(&cx, candidate).await;
        let merged = unwrap_outcome(outcome);

        assert_eq!(merged.id, Some(42));
        assert_eq!(merged.name, "Test");
        assert_eq!(session.debug_state().pending_new, 1);
    });

    let query_calls = state.lock().expect("lock poisoned").query_calls;
    assert_eq!(query_calls, 0);
}
5030
/// An object without a primary key is treated as new: even with `load = true`
/// no query is issued, and the object becomes pending-new.
#[test]
fn test_merge_null_pk_treated_as_new() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let keyless = Team {
            id: None,
            name: "Brand New".to_string(),
        };
        let outcome = session.merge(&cx, keyless, true).await;
        let merged = unwrap_outcome(outcome);

        assert_eq!(merged.id, None);
        assert_eq!(merged.name, "Brand New");
        assert_eq!(session.debug_state().pending_new, 1);
    });

    let query_calls = state.lock().expect("lock poisoned").query_calls;
    assert_eq!(query_calls, 0);
}
5063
/// A freshly added (pending-new) object counts as modified.
#[test]
fn test_is_modified_new_object_returns_true() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    let team = Team {
        id: Some(100),
        name: "New Team".to_string(),
    };
    session.add(&team);

    assert!(session.is_modified(&team));
}
5081
/// An object the session has never seen is reported as not modified.
#[test]
fn test_is_modified_untracked_returns_false() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let session = Session::<MockConnection>::new(MockConnection::new(Arc::clone(&shared)));

    let team = Team {
        id: Some(100),
        name: "Not Tracked".to_string(),
    };

    assert!(!session.is_modified(&team));
}
5096
/// An object freshly loaded via `get` is not modified.
#[test]
fn test_is_modified_after_load_returns_false() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();

        assert!(!session.is_modified(&team));
    });
}
5116
/// A loaded object starts clean; after a changed copy is passed to
/// `mark_dirty`, that copy reports as modified.
#[test]
fn test_is_modified_after_mark_dirty_returns_true() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();
        assert!(!session.is_modified(&team));

        let modified_team = Team {
            name: "Modified Name".to_string(),
            ..team.clone()
        };
        session.mark_dirty(&modified_team);

        assert!(session.is_modified(&modified_team));
    });
}
5142
/// Scheduling a clean persistent object for deletion makes it count as modified.
#[test]
fn test_is_modified_deleted_returns_true() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();
        assert!(!session.is_modified(&team));

        session.delete(&team);
        assert!(session.is_modified(&team));
    });
}
5166
/// After `expunge` removes an object from the session, it is no longer
/// reported as modified.
#[test]
fn test_is_modified_detached_returns_false() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();

        session.expunge(&team);
        assert!(!session.is_modified(&team));
    });
}
5189
/// `object_state` walks through the lifecycle: `None` when untracked, `New`
/// after `add`, `Persistent` after `get`, and `Deleted` after `delete`.
#[test]
fn test_object_state_returns_correct_state() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    let stranger = Team {
        id: Some(999),
        name: "Untracked".to_string(),
    };
    assert_eq!(session.object_state(&stranger), None);

    let fresh = Team {
        id: Some(100),
        name: "New".to_string(),
    };
    session.add(&fresh);
    assert_eq!(session.object_state(&fresh), Some(ObjectState::New));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let loaded = unwrap_outcome(fetched).unwrap();
        assert_eq!(session.object_state(&loaded), Some(ObjectState::Persistent));

        session.delete(&loaded);
        assert_eq!(session.object_state(&loaded), Some(ObjectState::Deleted));
    });
}
5232
/// A clean loaded object has no modified attributes; after a copy with a new
/// `name` is passed to `mark_dirty`, `name` is listed as modified.
#[test]
fn test_modified_attributes_returns_changed_columns() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();

        let modified = session.modified_attributes(&team);
        assert!(modified.is_empty());

        let renamed = Team {
            name: "Changed Name".to_string(),
            ..team.clone()
        };
        session.mark_dirty(&renamed);

        let modified = session.modified_attributes(&renamed);
        assert!(modified.contains(&"name"));
    });
}
5262
/// An untracked object yields an empty modified-attribute list.
#[test]
fn test_modified_attributes_untracked_returns_empty() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let session = Session::<MockConnection>::new(MockConnection::new(Arc::clone(&shared)));

    let outsider = Team {
        id: Some(100),
        name: "Not Tracked".to_string(),
    };

    let modified = session.modified_attributes(&outsider);
    assert!(modified.is_empty());
}
5277
/// A pending-new object has no modified attributes, even though `is_modified`
/// would be true for it.
#[test]
fn test_modified_attributes_new_returns_empty() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    let fresh = Team {
        id: Some(100),
        name: "New".to_string(),
    };
    session.add(&fresh);

    let modified = session.modified_attributes(&fresh);
    assert!(modified.is_empty());
}
5294
/// Expiring a loaded object with no attribute list (`None`) moves it from
/// `Persistent` to `Expired`.
#[test]
fn test_expire_marks_object_as_expired() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let conn = MockConnection::new(Arc::clone(&state));
    let mut session = Session::new(conn);

    rt.block_on(async {
        // `expect` states the precondition directly instead of pairing
        // `assert!(is_some())` with a bare `unwrap()`.
        let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await)
            .expect("expected team 1 to be loaded from the mock connection");

        assert!(!session.is_expired(&team));
        assert_eq!(session.object_state(&team), Some(ObjectState::Persistent));

        session.expire(&team, None);

        assert!(session.is_expired(&team));
        assert_eq!(session.object_state(&team), Some(ObjectState::Expired));
    });
}
5326
/// Expiring only `name` marks the object as expired and records `name` in its
/// set of expired attributes.
#[test]
fn test_expire_specific_attributes() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let conn = MockConnection::new(Arc::clone(&state));
    let mut session = Session::new(conn);

    rt.block_on(async {
        let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();

        session.expire(&team, Some(&["name"]));
        assert!(session.is_expired(&team));

        // `expired_attributes` is doubly optional. The outer layer must be
        // `Some` now that the object is expired. The inner layer must be
        // `Some(set)` because only a named subset of attributes was expired.
        // Chained `expect`s replace the `is_some()` + `unwrap()` pairs and say
        // which layer failed.
        let expired_set = session
            .expired_attributes(&team)
            .expect("expected the expired object to report expired attributes")
            .expect("expected a specific attribute set, not a full expiry");
        assert!(expired_set.contains("name"));
    });
}
5356
/// `expire_all` flips every loaded persistent object to expired.
#[test]
fn test_expire_all_marks_all_objects_expired() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let first = session.get::<Team>(&cx, 1_i64).await;
        let team1 = unwrap_outcome(first).unwrap();
        let second = session.get::<Team>(&cx, 2_i64).await;
        let team2 = unwrap_outcome(second).unwrap();

        // Both start out fresh...
        assert!(!session.is_expired(&team1));
        assert!(!session.is_expired(&team2));

        session.expire_all();

        // ...and are both expired afterwards.
        assert!(session.is_expired(&team1));
        assert!(session.is_expired(&team2));
    });
}
5385
/// `expire` leaves a pending-new object untouched: it stays `New` and is not
/// reported as expired.
#[test]
fn test_expire_does_not_affect_new_objects() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    let team = Team {
        id: Some(100),
        name: "New Team".to_string(),
    };
    session.add(&team);
    session.expire(&team, None);

    assert_eq!(session.object_state(&team), Some(ObjectState::New));
    assert!(!session.is_expired(&team));
}
5406
/// A second `get` for a fresh object is served from the identity map with no
/// extra query. After the object is expired, `get` queries again and the
/// reloaded object is `Persistent` and no longer expired.
#[test]
fn test_expired_object_reloads_on_get() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let state = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&state)));

    rt.block_on(async {
        let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
        assert_eq!(team.name, "Avengers");

        let cached = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
        assert_eq!(cached.name, "Avengers");

        // The lock guard is a temporary dropped at the end of the statement,
        // so it is released before the session touches the mock state again.
        assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);

        session.expire(&team, None);

        let team3 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
        assert_eq!(team3.name, "Avengers");

        assert_eq!(state.lock().expect("lock poisoned").query_calls, 2);

        assert!(!session.is_expired(&team3));
        assert_eq!(session.object_state(&team3), Some(ObjectState::Persistent));
    });
}
5451
/// An object the session does not track is never reported as expired.
#[test]
fn test_is_expired_returns_false_for_untracked() {
    let shared = Arc::new(Mutex::new(MockState::default()));
    let session = Session::<MockConnection>::new(MockConnection::new(Arc::clone(&shared)));

    let team = Team {
        id: Some(999),
        name: "Not Tracked".to_string(),
    };

    assert!(!session.is_expired(&team));
}
5466
/// A freshly loaded, unexpired object reports `None` from `expired_attributes`.
#[test]
fn test_expired_attributes_returns_none_for_persistent() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("create asupersync runtime");
    let cx = Cx::for_testing();

    let shared = Arc::new(Mutex::new(MockState::default()));
    let mut session = Session::new(MockConnection::new(Arc::clone(&shared)));

    rt.block_on(async {
        let fetched = session.get::<Team>(&cx, 1_i64).await;
        let team = unwrap_outcome(fetched).unwrap();

        let expired = session.expired_attributes(&team);
        assert!(expired.is_none());
    });
}
5487}