1pub mod change_tracker;
43pub mod flush;
44pub mod identity_map;
45pub mod n1_detection;
46pub mod unit_of_work;
47
48pub use change_tracker::{ChangeTracker, ObjectSnapshot};
49pub use flush::{
50 FlushOrderer, FlushPlan, FlushResult, LinkTableOp, PendingOp, execute_link_table_ops,
51};
52pub use identity_map::{IdentityMap, ModelReadGuard, ModelRef, ModelWriteGuard, WeakIdentityMap};
53pub use n1_detection::{CallSite, N1DetectionScope, N1QueryTracker, N1Stats};
54pub use unit_of_work::{PendingCounts, UnitOfWork, UowError};
55
56use asupersync::{Cx, Outcome};
57use serde::{Deserialize, Serialize};
58use sqlmodel_core::{Connection, Error, Lazy, LazyLoader, Model, Value};
59use std::any::{Any, TypeId};
60use std::collections::HashMap;
61use std::future::Future;
62use std::hash::{Hash, Hasher};
63
64type SessionEventFn = Box<dyn FnMut() -> Result<(), Error> + Send>;
73
74#[derive(Default)]
79pub struct SessionEventCallbacks {
80 before_flush: Vec<SessionEventFn>,
81 after_flush: Vec<SessionEventFn>,
82 before_commit: Vec<SessionEventFn>,
83 after_commit: Vec<SessionEventFn>,
84 after_rollback: Vec<SessionEventFn>,
85}
86
87impl std::fmt::Debug for SessionEventCallbacks {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("SessionEventCallbacks")
90 .field("before_flush", &self.before_flush.len())
91 .field("after_flush", &self.after_flush.len())
92 .field("before_commit", &self.before_commit.len())
93 .field("after_commit", &self.after_commit.len())
94 .field("after_rollback", &self.after_rollback.len())
95 .finish()
96 }
97}
98
99impl SessionEventCallbacks {
100 #[allow(clippy::result_large_err)]
101 fn fire(&mut self, event: SessionEvent) -> Result<(), Error> {
102 let callbacks = match event {
103 SessionEvent::BeforeFlush => &mut self.before_flush,
104 SessionEvent::AfterFlush => &mut self.after_flush,
105 SessionEvent::BeforeCommit => &mut self.before_commit,
106 SessionEvent::AfterCommit => &mut self.after_commit,
107 SessionEvent::AfterRollback => &mut self.after_rollback,
108 };
109 for cb in callbacks.iter_mut() {
110 cb()?;
111 }
112 Ok(())
113 }
114}
115
/// Lifecycle points at which registered session callbacks are fired.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionEvent {
    /// At the start of `Session::flush`, before any SQL is issued.
    BeforeFlush,
    /// After `Session::flush` has written all pending changes.
    AfterFlush,
    /// After the pre-commit flush, before `COMMIT` is sent.
    BeforeCommit,
    /// At the end of a successful `Session::commit`.
    AfterCommit,
    /// At the end of `Session::rollback`, after session state is reset.
    AfterRollback,
}
130
/// Behavioural switches for a [`Session`].
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Issue `BEGIN` automatically when flushing outside a transaction.
    pub auto_begin: bool,
    /// Auto-flush switch (off by default).
    /// NOTE(review): not read by any code visible in this file — confirm its consumer.
    pub auto_flush: bool,
    /// Mark every persistent object as expired after a successful commit.
    pub expire_on_commit: bool,
}

impl Default for SessionConfig {
    /// `auto_begin` and `expire_on_commit` enabled, `auto_flush` disabled.
    fn default() -> Self {
        Self {
            expire_on_commit: true,
            auto_flush: false,
            auto_begin: true,
        }
    }
}
155
/// Row-locking options for `Session::get_with_options`. All flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct GetOptions {
    /// Append `FOR UPDATE` to the query; this also bypasses the identity-map cache.
    pub with_for_update: bool,
    /// With `with_for_update`, append `SKIP LOCKED` (takes precedence over `nowait`).
    pub skip_locked: bool,
    /// With `with_for_update`, append `NOWAIT` (ignored when `skip_locked` is set).
    pub nowait: bool,
}

impl GetOptions {
    /// Options with every flag cleared (same as `GetOptions::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether the row should be locked with `FOR UPDATE`.
    #[must_use]
    pub fn with_for_update(self, value: bool) -> Self {
        Self {
            with_for_update: value,
            ..self
        }
    }

    /// Sets whether locked rows should be skipped (`SKIP LOCKED`).
    #[must_use]
    pub fn skip_locked(self, value: bool) -> Self {
        Self {
            skip_locked: value,
            ..self
        }
    }

    /// Sets whether to fail immediately on a locked row (`NOWAIT`).
    #[must_use]
    pub fn nowait(self, value: bool) -> Self {
        Self {
            nowait: value,
            ..self
        }
    }
}
195
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
202pub struct ObjectKey {
203 type_id: TypeId,
205 pk_hash: u64,
207}
208
209impl ObjectKey {
210 pub fn from_model<M: Model + 'static>(obj: &M) -> Self {
212 let pk_values = obj.primary_key_value();
213 Self {
214 type_id: TypeId::of::<M>(),
215 pk_hash: hash_values(&pk_values),
216 }
217 }
218
219 pub fn from_pk<M: Model + 'static>(pk: &[Value]) -> Self {
221 Self {
222 type_id: TypeId::of::<M>(),
223 pk_hash: hash_values(pk),
224 }
225 }
226
227 pub fn pk_hash(&self) -> u64 {
229 self.pk_hash
230 }
231
232 pub fn type_id(&self) -> TypeId {
234 self.type_id
235 }
236}
237
238fn hash_values(values: &[Value]) -> u64 {
240 use std::collections::hash_map::DefaultHasher;
241 let mut hasher = DefaultHasher::new();
242 for v in values {
243 match v {
245 Value::Null => 0u8.hash(&mut hasher),
246 Value::Bool(b) => {
247 1u8.hash(&mut hasher);
248 b.hash(&mut hasher);
249 }
250 Value::TinyInt(i) => {
251 2u8.hash(&mut hasher);
252 i.hash(&mut hasher);
253 }
254 Value::SmallInt(i) => {
255 3u8.hash(&mut hasher);
256 i.hash(&mut hasher);
257 }
258 Value::Int(i) => {
259 4u8.hash(&mut hasher);
260 i.hash(&mut hasher);
261 }
262 Value::BigInt(i) => {
263 5u8.hash(&mut hasher);
264 i.hash(&mut hasher);
265 }
266 Value::Float(f) => {
267 6u8.hash(&mut hasher);
268 f.to_bits().hash(&mut hasher);
269 }
270 Value::Double(f) => {
271 7u8.hash(&mut hasher);
272 f.to_bits().hash(&mut hasher);
273 }
274 Value::Decimal(s) => {
275 8u8.hash(&mut hasher);
276 s.hash(&mut hasher);
277 }
278 Value::Text(s) => {
279 9u8.hash(&mut hasher);
280 s.hash(&mut hasher);
281 }
282 Value::Bytes(b) => {
283 10u8.hash(&mut hasher);
284 b.hash(&mut hasher);
285 }
286 Value::Date(d) => {
287 11u8.hash(&mut hasher);
288 d.hash(&mut hasher);
289 }
290 Value::Time(t) => {
291 12u8.hash(&mut hasher);
292 t.hash(&mut hasher);
293 }
294 Value::Timestamp(ts) => {
295 13u8.hash(&mut hasher);
296 ts.hash(&mut hasher);
297 }
298 Value::TimestampTz(ts) => {
299 14u8.hash(&mut hasher);
300 ts.hash(&mut hasher);
301 }
302 Value::Uuid(u) => {
303 15u8.hash(&mut hasher);
304 u.hash(&mut hasher);
305 }
306 Value::Json(j) => {
307 16u8.hash(&mut hasher);
308 j.to_string().hash(&mut hasher);
310 }
311 Value::Array(arr) => {
312 17u8.hash(&mut hasher);
313 arr.len().hash(&mut hasher);
315 for item in arr {
316 hash_value(item, &mut hasher);
317 }
318 }
319 Value::Default => {
320 18u8.hash(&mut hasher);
321 }
322 }
323 }
324 hasher.finish()
325}
326
327fn hash_value(v: &Value, hasher: &mut impl Hasher) {
329 match v {
330 Value::Null => 0u8.hash(hasher),
331 Value::Bool(b) => {
332 1u8.hash(hasher);
333 b.hash(hasher);
334 }
335 Value::TinyInt(i) => {
336 2u8.hash(hasher);
337 i.hash(hasher);
338 }
339 Value::SmallInt(i) => {
340 3u8.hash(hasher);
341 i.hash(hasher);
342 }
343 Value::Int(i) => {
344 4u8.hash(hasher);
345 i.hash(hasher);
346 }
347 Value::BigInt(i) => {
348 5u8.hash(hasher);
349 i.hash(hasher);
350 }
351 Value::Float(f) => {
352 6u8.hash(hasher);
353 f.to_bits().hash(hasher);
354 }
355 Value::Double(f) => {
356 7u8.hash(hasher);
357 f.to_bits().hash(hasher);
358 }
359 Value::Decimal(s) => {
360 8u8.hash(hasher);
361 s.hash(hasher);
362 }
363 Value::Text(s) => {
364 9u8.hash(hasher);
365 s.hash(hasher);
366 }
367 Value::Bytes(b) => {
368 10u8.hash(hasher);
369 b.hash(hasher);
370 }
371 Value::Date(d) => {
372 11u8.hash(hasher);
373 d.hash(hasher);
374 }
375 Value::Time(t) => {
376 12u8.hash(hasher);
377 t.hash(hasher);
378 }
379 Value::Timestamp(ts) => {
380 13u8.hash(hasher);
381 ts.hash(hasher);
382 }
383 Value::TimestampTz(ts) => {
384 14u8.hash(hasher);
385 ts.hash(hasher);
386 }
387 Value::Uuid(u) => {
388 15u8.hash(hasher);
389 u.hash(hasher);
390 }
391 Value::Json(j) => {
392 16u8.hash(hasher);
393 j.to_string().hash(hasher);
394 }
395 Value::Array(arr) => {
396 17u8.hash(hasher);
397 arr.len().hash(hasher);
398 for item in arr {
399 hash_value(item, hasher);
400 }
401 }
402 Value::Default => {
403 18u8.hash(hasher);
404 }
405 }
406}
407
/// Lifecycle state of an object tracked by a [`Session`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObjectState {
    /// Added to the session and queued for INSERT; not yet in the database.
    New,
    /// Loaded from, or flushed to, the database.
    Persistent,
    /// Marked for deletion; the DELETE is issued on the next flush.
    Deleted,
    /// Expunged: still recorded, but no longer managed by the session.
    Detached,
    /// Cached values are stale; `get` reloads the object from the database.
    Expired,
}
422
423struct TrackedObject {
425 object: Box<dyn Any + Send + Sync>,
427 original_state: Option<Vec<u8>>,
429 state: ObjectState,
431 table_name: &'static str,
433 column_names: Vec<&'static str>,
435 values: Vec<Value>,
437 pk_columns: Vec<&'static str>,
439 pk_values: Vec<Value>,
441 expired_attributes: Option<std::collections::HashSet<String>>,
444}
445
446pub struct Session<C: Connection> {
455 connection: C,
457 in_transaction: bool,
459 identity_map: HashMap<ObjectKey, TrackedObject>,
461 pending_new: Vec<ObjectKey>,
463 pending_delete: Vec<ObjectKey>,
465 pending_dirty: Vec<ObjectKey>,
467 config: SessionConfig,
469 n1_tracker: Option<N1QueryTracker>,
471 event_callbacks: SessionEventCallbacks,
473}
474
475impl<C: Connection> Session<C> {
476 pub fn new(connection: C) -> Self {
478 Self::with_config(connection, SessionConfig::default())
479 }
480
481 pub fn with_config(connection: C, config: SessionConfig) -> Self {
483 Self {
484 connection,
485 in_transaction: false,
486 identity_map: HashMap::new(),
487 pending_new: Vec::new(),
488 pending_delete: Vec::new(),
489 pending_dirty: Vec::new(),
490 config,
491 n1_tracker: None,
492 event_callbacks: SessionEventCallbacks::default(),
493 }
494 }
495
496 pub fn connection(&self) -> &C {
498 &self.connection
499 }
500
501 pub fn config(&self) -> &SessionConfig {
503 &self.config
504 }
505
506 pub fn on_before_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
514 self.event_callbacks.before_flush.push(Box::new(f));
515 }
516
517 pub fn on_after_flush(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
519 self.event_callbacks.after_flush.push(Box::new(f));
520 }
521
522 pub fn on_before_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
526 self.event_callbacks.before_commit.push(Box::new(f));
527 }
528
529 pub fn on_after_commit(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
531 self.event_callbacks.after_commit.push(Box::new(f));
532 }
533
534 pub fn on_after_rollback(&mut self, f: impl FnMut() -> Result<(), Error> + Send + 'static) {
536 self.event_callbacks.after_rollback.push(Box::new(f));
537 }
538
539 pub fn add<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
547 let key = ObjectKey::from_model(obj);
548
549 if let Some(tracked) = self.identity_map.get_mut(&key) {
551 tracked.object = Box::new(obj.clone());
552
553 let row_data = obj.to_row();
555 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
556 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
557 tracked.pk_values = obj.primary_key_value();
558
559 if tracked.state == ObjectState::Deleted {
560 self.pending_delete.retain(|k| k != &key);
562
563 if tracked.original_state.is_some() {
564 tracked.state = ObjectState::Persistent;
566 } else {
567 tracked.state = ObjectState::New;
569 if !self.pending_new.contains(&key) {
570 self.pending_new.push(key);
571 }
572 }
573 }
574 return;
575 }
576
577 let row_data = obj.to_row();
579 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
580 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
581
582 let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
584 let pk_values = obj.primary_key_value();
585
586 let tracked = TrackedObject {
587 object: Box::new(obj.clone()),
588 original_state: None, state: ObjectState::New,
590 table_name: M::TABLE_NAME,
591 column_names,
592 values,
593 pk_columns,
594 pk_values,
595 expired_attributes: None,
596 };
597
598 self.identity_map.insert(key, tracked);
599 self.pending_new.push(key);
600 }
601
602 pub fn add_all<'a, M, I>(&mut self, objects: I)
619 where
620 M: Model + Clone + Send + Sync + Serialize + 'static,
621 I: IntoIterator<Item = &'a M>,
622 {
623 for obj in objects {
624 self.add(obj);
625 }
626 }
627
628 pub fn delete<M: Model + 'static>(&mut self, obj: &M) {
632 let key = ObjectKey::from_model(obj);
633
634 if let Some(tracked) = self.identity_map.get_mut(&key) {
635 match tracked.state {
636 ObjectState::New => {
637 self.identity_map.remove(&key);
639 self.pending_new.retain(|k| k != &key);
640 }
641 ObjectState::Persistent | ObjectState::Expired => {
642 tracked.state = ObjectState::Deleted;
643 self.pending_delete.push(key);
644 self.pending_dirty.retain(|k| k != &key);
645 }
646 ObjectState::Deleted | ObjectState::Detached => {
647 }
649 }
650 }
651 }
652
653 pub fn mark_dirty<M: Model + Clone + Send + Sync + Serialize + 'static>(&mut self, obj: &M) {
667 let key = ObjectKey::from_model(obj);
668
669 if let Some(tracked) = self.identity_map.get_mut(&key) {
670 if tracked.state != ObjectState::Persistent {
672 return;
673 }
674
675 tracked.object = Box::new(obj.clone());
677 let row_data = obj.to_row();
678 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
679 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
680 tracked.pk_values = obj.primary_key_value();
681
682 if !self.pending_dirty.contains(&key) {
684 self.pending_dirty.push(key);
685 }
686 }
687 }
688
689 pub async fn get<
693 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
694 >(
695 &mut self,
696 cx: &Cx,
697 pk: impl Into<Value>,
698 ) -> Outcome<Option<M>, Error> {
699 let pk_value = pk.into();
700 let pk_values = vec![pk_value.clone()];
701 let key = ObjectKey::from_pk::<M>(&pk_values);
702
703 if let Some(tracked) = self.identity_map.get(&key) {
705 match tracked.state {
706 ObjectState::Deleted | ObjectState::Detached => {
707 }
709 ObjectState::Expired => {
710 tracing::debug!("Object is expired, reloading from database");
712 }
713 ObjectState::New | ObjectState::Persistent => {
714 if let Some(obj) = tracked.object.downcast_ref::<M>() {
715 return Outcome::Ok(Some(obj.clone()));
716 }
717 }
718 }
719 }
720
721 let pk_col = M::PRIMARY_KEY.first().unwrap_or(&"id");
723 let sql = format!(
724 "SELECT * FROM \"{}\" WHERE \"{}\" = $1 LIMIT 1",
725 M::TABLE_NAME,
726 pk_col
727 );
728
729 let rows = match self.connection.query(cx, &sql, &[pk_value]).await {
730 Outcome::Ok(rows) => rows,
731 Outcome::Err(e) => return Outcome::Err(e),
732 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
733 Outcome::Panicked(p) => return Outcome::Panicked(p),
734 };
735
736 if rows.is_empty() {
737 return Outcome::Ok(None);
738 }
739
740 let obj = match M::from_row(&rows[0]) {
742 Ok(obj) => obj,
743 Err(e) => return Outcome::Err(e),
744 };
745
746 let row_data = obj.to_row();
748 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
749 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
750
751 let serialized = serde_json::to_vec(&values).ok();
753
754 let pk_columns: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
756 let obj_pk_values = obj.primary_key_value();
757
758 let tracked = TrackedObject {
759 object: Box::new(obj.clone()),
760 original_state: serialized,
761 state: ObjectState::Persistent,
762 table_name: M::TABLE_NAME,
763 column_names,
764 values,
765 pk_columns,
766 pk_values: obj_pk_values,
767 expired_attributes: None,
768 };
769
770 self.identity_map.insert(key, tracked);
771
772 Outcome::Ok(Some(obj))
773 }
774
775 pub async fn get_by_pk<
789 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
790 >(
791 &mut self,
792 cx: &Cx,
793 pk_values: &[Value],
794 ) -> Outcome<Option<M>, Error> {
795 self.get_with_options::<M>(cx, pk_values, &GetOptions::default())
796 .await
797 }
798
799 pub async fn get_with_options<
812 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
813 >(
814 &mut self,
815 cx: &Cx,
816 pk_values: &[Value],
817 options: &GetOptions,
818 ) -> Outcome<Option<M>, Error> {
819 let key = ObjectKey::from_pk::<M>(pk_values);
820
821 if !options.with_for_update {
823 if let Some(tracked) = self.identity_map.get(&key) {
824 match tracked.state {
825 ObjectState::Deleted | ObjectState::Detached => {
826 }
828 ObjectState::Expired => {
829 tracing::debug!("Object is expired, reloading from database");
831 }
832 ObjectState::New | ObjectState::Persistent => {
833 if let Some(obj) = tracked.object.downcast_ref::<M>() {
834 return Outcome::Ok(Some(obj.clone()));
835 }
836 }
837 }
838 }
839 }
840
841 let pk_columns = M::PRIMARY_KEY;
843 if pk_columns.len() != pk_values.len() {
844 return Outcome::Err(Error::Custom(format!(
845 "Primary key mismatch: expected {} values, got {}",
846 pk_columns.len(),
847 pk_values.len()
848 )));
849 }
850
851 let where_parts: Vec<String> = pk_columns
852 .iter()
853 .enumerate()
854 .map(|(i, col)| format!("\"{}\" = ${}", col, i + 1))
855 .collect();
856
857 let mut sql = format!(
858 "SELECT * FROM \"{}\" WHERE {} LIMIT 1",
859 M::TABLE_NAME,
860 where_parts.join(" AND ")
861 );
862
863 if options.with_for_update {
865 sql.push_str(" FOR UPDATE");
866 if options.skip_locked {
867 sql.push_str(" SKIP LOCKED");
868 } else if options.nowait {
869 sql.push_str(" NOWAIT");
870 }
871 }
872
873 let rows = match self.connection.query(cx, &sql, pk_values).await {
874 Outcome::Ok(rows) => rows,
875 Outcome::Err(e) => return Outcome::Err(e),
876 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
877 Outcome::Panicked(p) => return Outcome::Panicked(p),
878 };
879
880 if rows.is_empty() {
881 return Outcome::Ok(None);
882 }
883
884 let obj = match M::from_row(&rows[0]) {
886 Ok(obj) => obj,
887 Err(e) => return Outcome::Err(e),
888 };
889
890 let row_data = obj.to_row();
892 let column_names: Vec<&'static str> = row_data.iter().map(|(name, _)| *name).collect();
893 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
894
895 let serialized = serde_json::to_vec(&values).ok();
897
898 let pk_cols: Vec<&'static str> = M::PRIMARY_KEY.to_vec();
900 let obj_pk_values = obj.primary_key_value();
901
902 let tracked = TrackedObject {
903 object: Box::new(obj.clone()),
904 original_state: serialized,
905 state: ObjectState::Persistent,
906 table_name: M::TABLE_NAME,
907 column_names,
908 values,
909 pk_columns: pk_cols,
910 pk_values: obj_pk_values,
911 expired_attributes: None,
912 };
913
914 self.identity_map.insert(key, tracked);
915
916 Outcome::Ok(Some(obj))
917 }
918
919 pub fn contains<M: Model + 'static>(&self, obj: &M) -> bool {
921 let key = ObjectKey::from_model(obj);
922 self.identity_map.contains_key(&key)
923 }
924
925 pub fn expunge<M: Model + 'static>(&mut self, obj: &M) {
927 let key = ObjectKey::from_model(obj);
928 if let Some(tracked) = self.identity_map.get_mut(&key) {
929 tracked.state = ObjectState::Detached;
930 }
931 self.pending_new.retain(|k| k != &key);
932 self.pending_delete.retain(|k| k != &key);
933 self.pending_dirty.retain(|k| k != &key);
934 }
935
936 pub fn expunge_all(&mut self) {
938 for tracked in self.identity_map.values_mut() {
939 tracked.state = ObjectState::Detached;
940 }
941 self.pending_new.clear();
942 self.pending_delete.clear();
943 self.pending_dirty.clear();
944 }
945
946 pub fn is_modified<M: Model + Serialize + 'static>(&self, obj: &M) -> bool {
975 let key = ObjectKey::from_model(obj);
976
977 let Some(tracked) = self.identity_map.get(&key) else {
978 return false;
979 };
980
981 match tracked.state {
982 ObjectState::New => true,
984
985 ObjectState::Deleted => true,
987
988 ObjectState::Detached | ObjectState::Expired => false,
990
991 ObjectState::Persistent => {
993 if self.pending_dirty.contains(&key) {
995 return true;
996 }
997
998 let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
1000 tracked.original_state.as_ref() != Some(¤t_state)
1001 }
1002 }
1003 }
1004
1005 pub fn modified_attributes<M: Model + Serialize + 'static>(
1024 &self,
1025 obj: &M,
1026 ) -> Vec<&'static str> {
1027 let key = ObjectKey::from_model(obj);
1028
1029 let Some(tracked) = self.identity_map.get(&key) else {
1030 return Vec::new();
1031 };
1032
1033 if tracked.state != ObjectState::Persistent {
1035 return Vec::new();
1036 }
1037
1038 let Some(original_bytes) = &tracked.original_state else {
1040 return Vec::new();
1041 };
1042
1043 let Ok(original_values): Result<Vec<Value>, _> = serde_json::from_slice(original_bytes)
1045 else {
1046 return Vec::new();
1047 };
1048
1049 let mut modified = Vec::new();
1051 for (i, col) in tracked.column_names.iter().enumerate() {
1052 let current = tracked.values.get(i);
1053 let original = original_values.get(i);
1054
1055 if current != original {
1056 modified.push(*col);
1057 }
1058 }
1059
1060 modified
1061 }
1062
1063 pub fn object_state<M: Model + 'static>(&self, obj: &M) -> Option<ObjectState> {
1067 let key = ObjectKey::from_model(obj);
1068 self.identity_map.get(&key).map(|t| t.state)
1069 }
1070
1071 #[tracing::instrument(level = "debug", skip(self, obj), fields(table = M::TABLE_NAME))]
1105 pub fn expire<M: Model + 'static>(&mut self, obj: &M, attributes: Option<&[&str]>) {
1106 let key = ObjectKey::from_model(obj);
1107
1108 let Some(tracked) = self.identity_map.get_mut(&key) else {
1109 tracing::debug!("Object not tracked, nothing to expire");
1110 return;
1111 };
1112
1113 match tracked.state {
1115 ObjectState::New | ObjectState::Detached | ObjectState::Deleted => {
1116 tracing::debug!(state = ?tracked.state, "Cannot expire object in this state");
1117 return;
1118 }
1119 ObjectState::Persistent | ObjectState::Expired => {}
1120 }
1121
1122 match attributes {
1123 None => {
1124 tracked.state = ObjectState::Expired;
1126 tracked.expired_attributes = None;
1127 tracing::debug!("Expired all attributes");
1128 }
1129 Some(attrs) => {
1130 let mut expired = tracked.expired_attributes.take().unwrap_or_default();
1132 for attr in attrs {
1133 expired.insert((*attr).to_string());
1134 }
1135 tracked.expired_attributes = Some(expired);
1136
1137 if tracked.state == ObjectState::Persistent {
1139 tracked.state = ObjectState::Expired;
1140 }
1141 tracing::debug!(attributes = ?attrs, "Expired specific attributes");
1142 }
1143 }
1144 }
1145
1146 #[tracing::instrument(level = "debug", skip(self))]
1167 pub fn expire_all(&mut self) {
1168 let mut expired_count = 0;
1169 for tracked in self.identity_map.values_mut() {
1170 if tracked.state == ObjectState::Persistent {
1171 tracked.state = ObjectState::Expired;
1172 tracked.expired_attributes = None;
1173 expired_count += 1;
1174 }
1175 }
1176 tracing::debug!(count = expired_count, "Expired all session objects");
1177 }
1178
1179 pub fn is_expired<M: Model + 'static>(&self, obj: &M) -> bool {
1184 let key = ObjectKey::from_model(obj);
1185 self.identity_map
1186 .get(&key)
1187 .is_some_and(|t| t.state == ObjectState::Expired)
1188 }
1189
1190 pub fn expired_attributes<M: Model + 'static>(
1197 &self,
1198 obj: &M,
1199 ) -> Option<Option<&std::collections::HashSet<String>>> {
1200 let key = ObjectKey::from_model(obj);
1201 let tracked = self.identity_map.get(&key)?;
1202
1203 if tracked.state != ObjectState::Expired {
1204 return None;
1205 }
1206
1207 Some(tracked.expired_attributes.as_ref())
1208 }
1209
1210 #[tracing::instrument(level = "debug", skip(self, cx, obj), fields(table = M::TABLE_NAME))]
1245 pub async fn refresh<
1246 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1247 >(
1248 &mut self,
1249 cx: &Cx,
1250 obj: &M,
1251 ) -> Outcome<Option<M>, Error> {
1252 let pk_values = obj.primary_key_value();
1253 let key = ObjectKey::from_model(obj);
1254
1255 tracing::debug!(pk = ?pk_values, "Refreshing object from database");
1256
1257 self.pending_dirty.retain(|k| k != &key);
1259
1260 self.identity_map.remove(&key);
1262
1263 let result = self.get_by_pk::<M>(cx, &pk_values).await;
1265
1266 match &result {
1267 Outcome::Ok(Some(_)) => {
1268 tracing::debug!("Object refreshed successfully");
1269 }
1270 Outcome::Ok(None) => {
1271 tracing::debug!("Object no longer exists in database");
1272 }
1273 _ => {}
1274 }
1275
1276 result
1277 }
1278
1279 pub async fn begin(&mut self, cx: &Cx) -> Outcome<(), Error> {
1285 if self.in_transaction {
1286 return Outcome::Ok(());
1287 }
1288
1289 match self.connection.execute(cx, "BEGIN", &[]).await {
1290 Outcome::Ok(_) => {
1291 self.in_transaction = true;
1292 Outcome::Ok(())
1293 }
1294 Outcome::Err(e) => Outcome::Err(e),
1295 Outcome::Cancelled(r) => Outcome::Cancelled(r),
1296 Outcome::Panicked(p) => Outcome::Panicked(p),
1297 }
1298 }
1299
1300 pub async fn flush(&mut self, cx: &Cx) -> Outcome<(), Error> {
1304 if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeFlush) {
1306 return Outcome::Err(e);
1307 }
1308
1309 if self.config.auto_begin && !self.in_transaction {
1311 match self.begin(cx).await {
1312 Outcome::Ok(()) => {}
1313 Outcome::Err(e) => return Outcome::Err(e),
1314 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1315 Outcome::Panicked(p) => return Outcome::Panicked(p),
1316 }
1317 }
1318
1319 let deletes: Vec<ObjectKey> = std::mem::take(&mut self.pending_delete);
1321 let mut actually_deleted: Vec<ObjectKey> = Vec::new();
1322 for key in &deletes {
1323 if let Some(tracked) = self.identity_map.get(key) {
1324 if tracked.state != ObjectState::Deleted {
1326 continue;
1327 }
1328
1329 if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
1331 tracing::warn!(
1332 table = tracked.table_name,
1333 "Skipping DELETE for object without primary key - cannot identify row"
1334 );
1335 continue;
1336 }
1337
1338 let where_parts: Vec<String> = tracked
1340 .pk_columns
1341 .iter()
1342 .enumerate()
1343 .map(|(i, col)| format!("\"{}\" = ${}", col, i + 1))
1344 .collect();
1345
1346 let sql = format!(
1347 "DELETE FROM \"{}\" WHERE {}",
1348 tracked.table_name,
1349 where_parts.join(" AND ")
1350 );
1351
1352 match self.connection.execute(cx, &sql, &tracked.pk_values).await {
1353 Outcome::Ok(_) => {
1354 actually_deleted.push(*key);
1355 }
1356 Outcome::Err(e) => {
1357 self.pending_delete = deletes
1360 .into_iter()
1361 .filter(|k| !actually_deleted.contains(k))
1362 .collect();
1363 for key in &actually_deleted {
1365 self.identity_map.remove(key);
1366 }
1367 return Outcome::Err(e);
1368 }
1369 Outcome::Cancelled(r) => {
1370 self.pending_delete = deletes
1372 .into_iter()
1373 .filter(|k| !actually_deleted.contains(k))
1374 .collect();
1375 for key in &actually_deleted {
1376 self.identity_map.remove(key);
1377 }
1378 return Outcome::Cancelled(r);
1379 }
1380 Outcome::Panicked(p) => {
1381 self.pending_delete = deletes
1383 .into_iter()
1384 .filter(|k| !actually_deleted.contains(k))
1385 .collect();
1386 for key in &actually_deleted {
1387 self.identity_map.remove(key);
1388 }
1389 return Outcome::Panicked(p);
1390 }
1391 }
1392 }
1393 }
1394
1395 for key in &actually_deleted {
1397 self.identity_map.remove(key);
1398 }
1399
1400 let inserts: Vec<ObjectKey> = std::mem::take(&mut self.pending_new);
1402 for key in &inserts {
1403 if let Some(tracked) = self.identity_map.get_mut(key) {
1404 if tracked.state == ObjectState::Persistent {
1406 continue;
1407 }
1408
1409 let columns = &tracked.column_names;
1411 let placeholders: Vec<String> =
1412 (1..=columns.len()).map(|i| format!("${}", i)).collect();
1413
1414 let sql = format!(
1415 "INSERT INTO \"{}\" ({}) VALUES ({})",
1416 tracked.table_name,
1417 columns
1418 .iter()
1419 .map(|c| format!("\"{}\"", c))
1420 .collect::<Vec<_>>()
1421 .join(", "),
1422 placeholders.join(", ")
1423 );
1424
1425 match self.connection.execute(cx, &sql, &tracked.values).await {
1426 Outcome::Ok(_) => {
1427 tracked.state = ObjectState::Persistent;
1428 tracked.original_state =
1430 Some(serde_json::to_vec(&tracked.values).unwrap_or_default());
1431 }
1432 Outcome::Err(e) => {
1433 self.pending_new = inserts;
1435 return Outcome::Err(e);
1436 }
1437 Outcome::Cancelled(r) => {
1438 self.pending_new = inserts;
1440 return Outcome::Cancelled(r);
1441 }
1442 Outcome::Panicked(p) => {
1443 self.pending_new = inserts;
1445 return Outcome::Panicked(p);
1446 }
1447 }
1448 }
1449 }
1450
1451 let dirty: Vec<ObjectKey> = std::mem::take(&mut self.pending_dirty);
1453 for key in &dirty {
1454 if let Some(tracked) = self.identity_map.get_mut(key) {
1455 if tracked.state != ObjectState::Persistent {
1457 continue;
1458 }
1459
1460 if tracked.pk_columns.is_empty() || tracked.pk_values.is_empty() {
1462 tracing::warn!(
1463 table = tracked.table_name,
1464 "Skipping UPDATE for object without primary key - cannot identify row"
1465 );
1466 continue;
1467 }
1468
1469 let current_state = serde_json::to_vec(&tracked.values).unwrap_or_default();
1471 let is_dirty = tracked.original_state.as_ref() != Some(¤t_state);
1472
1473 if !is_dirty {
1474 continue;
1475 }
1476
1477 let mut set_parts = Vec::new();
1479 let mut params = Vec::new();
1480 let mut param_idx = 1;
1481
1482 for (i, col) in tracked.column_names.iter().enumerate() {
1483 if !tracked.pk_columns.contains(col) {
1485 set_parts.push(format!("\"{}\" = ${}", col, param_idx));
1486 params.push(tracked.values[i].clone());
1487 param_idx += 1;
1488 }
1489 }
1490
1491 let where_parts: Vec<String> = tracked
1493 .pk_columns
1494 .iter()
1495 .map(|col| {
1496 let clause = format!("\"{}\" = ${}", col, param_idx);
1497 param_idx += 1;
1498 clause
1499 })
1500 .collect();
1501
1502 params.extend(tracked.pk_values.clone());
1504
1505 if set_parts.is_empty() {
1506 continue; }
1508
1509 let sql = format!(
1510 "UPDATE \"{}\" SET {} WHERE {}",
1511 tracked.table_name,
1512 set_parts.join(", "),
1513 where_parts.join(" AND ")
1514 );
1515
1516 match self.connection.execute(cx, &sql, ¶ms).await {
1517 Outcome::Ok(_) => {
1518 tracked.original_state = Some(current_state);
1520 }
1521 Outcome::Err(e) => {
1522 self.pending_dirty = dirty;
1524 return Outcome::Err(e);
1525 }
1526 Outcome::Cancelled(r) => {
1527 self.pending_dirty = dirty;
1529 return Outcome::Cancelled(r);
1530 }
1531 Outcome::Panicked(p) => {
1532 self.pending_dirty = dirty;
1534 return Outcome::Panicked(p);
1535 }
1536 }
1537 }
1538 }
1539
1540 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterFlush) {
1542 return Outcome::Err(e);
1543 }
1544
1545 Outcome::Ok(())
1546 }
1547
1548 pub async fn commit(&mut self, cx: &Cx) -> Outcome<(), Error> {
1550 match self.flush(cx).await {
1552 Outcome::Ok(()) => {}
1553 Outcome::Err(e) => return Outcome::Err(e),
1554 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1555 Outcome::Panicked(p) => return Outcome::Panicked(p),
1556 }
1557
1558 if let Err(e) = self.event_callbacks.fire(SessionEvent::BeforeCommit) {
1560 return Outcome::Err(e);
1561 }
1562
1563 if self.in_transaction {
1564 match self.connection.execute(cx, "COMMIT", &[]).await {
1565 Outcome::Ok(_) => {
1566 self.in_transaction = false;
1567 }
1568 Outcome::Err(e) => return Outcome::Err(e),
1569 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1570 Outcome::Panicked(p) => return Outcome::Panicked(p),
1571 }
1572 }
1573
1574 if self.config.expire_on_commit {
1576 for tracked in self.identity_map.values_mut() {
1577 if tracked.state == ObjectState::Persistent {
1578 tracked.state = ObjectState::Expired;
1579 }
1580 }
1581 }
1582
1583 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterCommit) {
1585 return Outcome::Err(e);
1586 }
1587
1588 Outcome::Ok(())
1589 }
1590
1591 pub async fn rollback(&mut self, cx: &Cx) -> Outcome<(), Error> {
1593 if self.in_transaction {
1594 match self.connection.execute(cx, "ROLLBACK", &[]).await {
1595 Outcome::Ok(_) => {
1596 self.in_transaction = false;
1597 }
1598 Outcome::Err(e) => return Outcome::Err(e),
1599 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1600 Outcome::Panicked(p) => return Outcome::Panicked(p),
1601 }
1602 }
1603
1604 self.pending_new.clear();
1606 self.pending_delete.clear();
1607 self.pending_dirty.clear();
1608
1609 let mut to_remove = Vec::new();
1611 for (key, tracked) in &mut self.identity_map {
1612 match tracked.state {
1613 ObjectState::New => {
1614 to_remove.push(*key);
1615 }
1616 ObjectState::Deleted => {
1617 tracked.state = ObjectState::Persistent;
1618 }
1619 _ => {}
1620 }
1621 }
1622
1623 for key in to_remove {
1624 self.identity_map.remove(&key);
1625 }
1626
1627 if let Err(e) = self.event_callbacks.fire(SessionEvent::AfterRollback) {
1629 return Outcome::Err(e);
1630 }
1631
1632 Outcome::Ok(())
1633 }
1634
1635 #[tracing::instrument(level = "debug", skip(self, lazy, cx))]
1651 pub async fn load_lazy<
1652 T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1653 >(
1654 &mut self,
1655 lazy: &Lazy<T>,
1656 cx: &Cx,
1657 ) -> Outcome<bool, Error> {
1658 tracing::debug!(
1659 model = std::any::type_name::<T>(),
1660 fk = ?lazy.fk(),
1661 already_loaded = lazy.is_loaded(),
1662 "Loading lazy relationship"
1663 );
1664
1665 if lazy.is_loaded() {
1667 tracing::trace!("Already loaded");
1668 return Outcome::Ok(lazy.get().is_some());
1669 }
1670
1671 let Some(fk) = lazy.fk() else {
1673 let _ = lazy.set_loaded(None);
1674 return Outcome::Ok(false);
1675 };
1676
1677 let obj = match self.get::<T>(cx, fk.clone()).await {
1679 Outcome::Ok(obj) => obj,
1680 Outcome::Err(e) => return Outcome::Err(e),
1681 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1682 Outcome::Panicked(p) => return Outcome::Panicked(p),
1683 };
1684
1685 let found = obj.is_some();
1686
1687 let _ = lazy.set_loaded(obj);
1689
1690 tracing::debug!(found = found, "Lazy load complete");
1691
1692 Outcome::Ok(found)
1693 }
1694
1695 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor))]
1719 pub async fn load_many<P, T, F>(
1720 &mut self,
1721 cx: &Cx,
1722 objects: &[P],
1723 accessor: F,
1724 ) -> Outcome<usize, Error>
1725 where
1726 P: Model + 'static,
1727 T: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1728 F: Fn(&P) -> &Lazy<T>,
1729 {
1730 let mut fk_values: Vec<Value> = Vec::new();
1732 let mut fk_indices: Vec<usize> = Vec::new();
1733
1734 for (idx, obj) in objects.iter().enumerate() {
1735 let lazy = accessor(obj);
1736 if !lazy.is_loaded() && !lazy.is_empty() {
1737 if let Some(fk) = lazy.fk() {
1738 fk_values.push(fk.clone());
1739 fk_indices.push(idx);
1740 }
1741 }
1742 }
1743
1744 let fk_count = fk_values.len();
1745 tracing::info!(
1746 parent_model = std::any::type_name::<P>(),
1747 related_model = std::any::type_name::<T>(),
1748 parent_count = objects.len(),
1749 fk_count = fk_count,
1750 "Batch loading lazy relationships"
1751 );
1752
1753 if fk_values.is_empty() {
1754 for obj in objects {
1756 let lazy = accessor(obj);
1757 if !lazy.is_loaded() && lazy.is_empty() {
1758 let _ = lazy.set_loaded(None);
1759 }
1760 }
1761 return Outcome::Ok(0);
1762 }
1763
1764 let pk_col = T::PRIMARY_KEY.first().unwrap_or(&"id");
1766 let placeholders: Vec<String> = (1..=fk_values.len()).map(|i| format!("${}", i)).collect();
1767 let sql = format!(
1768 "SELECT * FROM \"{}\" WHERE \"{}\" IN ({})",
1769 T::TABLE_NAME,
1770 pk_col,
1771 placeholders.join(", ")
1772 );
1773
1774 let rows = match self.connection.query(cx, &sql, &fk_values).await {
1775 Outcome::Ok(rows) => rows,
1776 Outcome::Err(e) => return Outcome::Err(e),
1777 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1778 Outcome::Panicked(p) => return Outcome::Panicked(p),
1779 };
1780
1781 let mut lookup: HashMap<u64, T> = HashMap::new();
1783 for row in &rows {
1784 match T::from_row(row) {
1785 Ok(obj) => {
1786 let pk_values = obj.primary_key_value();
1787 let pk_hash = hash_values(&pk_values);
1788
1789 let key = ObjectKey::from_pk::<T>(&pk_values);
1791
1792 let row_data = obj.to_row();
1794 let column_names: Vec<&'static str> =
1795 row_data.iter().map(|(name, _)| *name).collect();
1796 let values: Vec<Value> = row_data.into_iter().map(|(_, v)| v).collect();
1797
1798 let serialized = serde_json::to_vec(&values).ok();
1800
1801 let tracked = TrackedObject {
1802 object: Box::new(obj.clone()),
1803 original_state: serialized,
1804 state: ObjectState::Persistent,
1805 table_name: T::TABLE_NAME,
1806 column_names,
1807 values,
1808 pk_columns: T::PRIMARY_KEY.to_vec(),
1809 pk_values: pk_values.clone(),
1810 expired_attributes: None,
1811 };
1812 self.identity_map.insert(key, tracked);
1813
1814 lookup.insert(pk_hash, obj);
1816 }
1817 Err(_) => continue,
1818 }
1819 }
1820
1821 let mut loaded_count = 0;
1823 for obj in objects {
1824 let lazy = accessor(obj);
1825 if !lazy.is_loaded() {
1826 if let Some(fk) = lazy.fk() {
1827 let fk_hash = hash_values(std::slice::from_ref(fk));
1828 let related = lookup.get(&fk_hash).cloned();
1829 let found = related.is_some();
1830 let _ = lazy.set_loaded(related);
1831 if found {
1832 loaded_count += 1;
1833 }
1834 } else {
1835 let _ = lazy.set_loaded(None);
1836 }
1837 }
1838 }
1839
1840 tracing::debug!(
1841 query_count = 1,
1842 loaded_count = loaded_count,
1843 "Batch load complete"
1844 );
1845
1846 Outcome::Ok(loaded_count)
1847 }
1848
1849 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
1873 pub async fn load_many_to_many<P, Child, FA, FP>(
1874 &mut self,
1875 cx: &Cx,
1876 objects: &mut [P],
1877 accessor: FA,
1878 parent_pk: FP,
1879 link_table: &sqlmodel_core::LinkTableInfo,
1880 ) -> Outcome<usize, Error>
1881 where
1882 P: Model + 'static,
1883 Child: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
1884 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
1885 FP: Fn(&P) -> Value,
1886 {
1887 let pks: Vec<Value> = objects.iter().map(&parent_pk).collect();
1889
1890 tracing::info!(
1891 parent_model = std::any::type_name::<P>(),
1892 related_model = std::any::type_name::<Child>(),
1893 parent_count = pks.len(),
1894 link_table = link_table.table_name,
1895 "Batch loading many-to-many relationships"
1896 );
1897
1898 if pks.is_empty() {
1899 return Outcome::Ok(0);
1900 }
1901
1902 let child_pk_col = Child::PRIMARY_KEY.first().unwrap_or(&"id");
1908 let placeholders: Vec<String> = (1..=pks.len()).map(|i| format!("${}", i)).collect();
1909 let sql = format!(
1910 "SELECT \"{}\".*, \"{}\".\"{}\" AS __parent_pk FROM \"{}\" \
1911 JOIN \"{}\" ON \"{}\".\"{}\" = \"{}\".\"{}\" \
1912 WHERE \"{}\".\"{}\" IN ({})",
1913 Child::TABLE_NAME,
1914 link_table.table_name,
1915 link_table.local_column,
1916 Child::TABLE_NAME,
1917 link_table.table_name,
1918 Child::TABLE_NAME,
1919 child_pk_col,
1920 link_table.table_name,
1921 link_table.remote_column,
1922 link_table.table_name,
1923 link_table.local_column,
1924 placeholders.join(", ")
1925 );
1926
1927 tracing::trace!(sql = %sql, "Many-to-many batch SQL");
1928
1929 let rows = match self.connection.query(cx, &sql, &pks).await {
1930 Outcome::Ok(rows) => rows,
1931 Outcome::Err(e) => return Outcome::Err(e),
1932 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1933 Outcome::Panicked(p) => return Outcome::Panicked(p),
1934 };
1935
1936 let mut by_parent: HashMap<u64, Vec<Child>> = HashMap::new();
1938 for row in &rows {
1939 let parent_pk_value: Value = match row.get_by_name("__parent_pk") {
1941 Some(v) => v.clone(),
1942 None => continue,
1943 };
1944 let parent_pk_hash = hash_values(std::slice::from_ref(&parent_pk_value));
1945
1946 match Child::from_row(row) {
1948 Ok(child) => {
1949 by_parent.entry(parent_pk_hash).or_default().push(child);
1950 }
1951 Err(_) => continue,
1952 }
1953 }
1954
1955 let mut loaded_count = 0;
1957 for obj in objects {
1958 let pk = parent_pk(obj);
1959 let pk_hash = hash_values(std::slice::from_ref(&pk));
1960 let children = by_parent.remove(&pk_hash).unwrap_or_default();
1961 let child_count = children.len();
1962
1963 let related = accessor(obj);
1964 related.set_parent_pk(pk);
1965 let _ = related.set_loaded(children);
1966 loaded_count += child_count;
1967 }
1968
1969 tracing::debug!(
1970 query_count = 1,
1971 total_children = loaded_count,
1972 "Many-to-many batch load complete"
1973 );
1974
1975 Outcome::Ok(loaded_count)
1976 }
1977
1978 #[tracing::instrument(level = "debug", skip(self, cx, objects, accessor, parent_pk))]
1997 pub async fn flush_related_many<P, Child, FA, FP>(
1998 &mut self,
1999 cx: &Cx,
2000 objects: &mut [P],
2001 accessor: FA,
2002 parent_pk: FP,
2003 link_table: &sqlmodel_core::LinkTableInfo,
2004 ) -> Outcome<usize, Error>
2005 where
2006 P: Model + 'static,
2007 Child: Model + 'static,
2008 FA: Fn(&mut P) -> &mut sqlmodel_core::RelatedMany<Child>,
2009 FP: Fn(&P) -> Value,
2010 {
2011 let mut ops = Vec::new();
2012
2013 for obj in objects.iter_mut() {
2015 let parent_pk_value = parent_pk(obj);
2016 let related = accessor(obj);
2017
2018 for child_pk_values in related.take_pending_links() {
2020 if let Some(child_pk) = child_pk_values.first() {
2021 ops.push(LinkTableOp::link(
2022 link_table.table_name.to_string(),
2023 link_table.local_column.to_string(),
2024 parent_pk_value.clone(),
2025 link_table.remote_column.to_string(),
2026 child_pk.clone(),
2027 ));
2028 }
2029 }
2030
2031 for child_pk_values in related.take_pending_unlinks() {
2033 if let Some(child_pk) = child_pk_values.first() {
2034 ops.push(LinkTableOp::unlink(
2035 link_table.table_name.to_string(),
2036 link_table.local_column.to_string(),
2037 parent_pk_value.clone(),
2038 link_table.remote_column.to_string(),
2039 child_pk.clone(),
2040 ));
2041 }
2042 }
2043 }
2044
2045 if ops.is_empty() {
2046 return Outcome::Ok(0);
2047 }
2048
2049 tracing::info!(
2050 parent_model = std::any::type_name::<P>(),
2051 related_model = std::any::type_name::<Child>(),
2052 link_count = ops
2053 .iter()
2054 .filter(|o| matches!(o, LinkTableOp::Link { .. }))
2055 .count(),
2056 unlink_count = ops
2057 .iter()
2058 .filter(|o| matches!(o, LinkTableOp::Unlink { .. }))
2059 .count(),
2060 link_table = link_table.table_name,
2061 "Flushing many-to-many relationship changes"
2062 );
2063
2064 flush::execute_link_table_ops(cx, &self.connection, &ops).await
2065 }
2066
2067 pub fn relate_to_one<Child, Parent, FC, FP, FK>(
2092 &self,
2093 child: &mut Child,
2094 child_accessor: FC,
2095 set_fk: FK,
2096 parent: &mut Parent,
2097 parent_accessor: FP,
2098 ) where
2099 Child: Model + Clone + 'static,
2100 Parent: Model + Clone + 'static,
2101 FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2102 FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2103 FK: FnOnce(&mut Child),
2104 {
2105 let related = child_accessor(child);
2107 let _ = related.set_loaded(Some(parent.clone()));
2108
2109 set_fk(child);
2111
2112 let related_many = parent_accessor(parent);
2114 related_many.link(child);
2115
2116 tracing::debug!(
2117 child_model = std::any::type_name::<Child>(),
2118 parent_model = std::any::type_name::<Parent>(),
2119 "Established bidirectional ManyToOne <-> OneToMany relationship"
2120 );
2121 }
2122
2123 pub fn unrelate_from_one<Child, Parent, FC, FP, FK>(
2139 &self,
2140 child: &mut Child,
2141 child_accessor: FC,
2142 clear_fk: FK,
2143 parent: &mut Parent,
2144 parent_accessor: FP,
2145 ) where
2146 Child: Model + Clone + 'static,
2147 Parent: Model + Clone + 'static,
2148 FC: FnOnce(&mut Child) -> &mut sqlmodel_core::Related<Parent>,
2149 FP: FnOnce(&mut Parent) -> &mut sqlmodel_core::RelatedMany<Child>,
2150 FK: FnOnce(&mut Child),
2151 {
2152 let related = child_accessor(child);
2154 *related = sqlmodel_core::Related::empty();
2155
2156 clear_fk(child);
2158
2159 let related_many = parent_accessor(parent);
2161 related_many.unlink(child);
2162
2163 tracing::debug!(
2164 child_model = std::any::type_name::<Child>(),
2165 parent_model = std::any::type_name::<Parent>(),
2166 "Removed bidirectional ManyToOne <-> OneToMany relationship"
2167 );
2168 }
2169
2170 pub fn relate_many_to_many<Left, Right, FL, FR>(
2189 &self,
2190 left: &mut Left,
2191 left_accessor: FL,
2192 right: &mut Right,
2193 right_accessor: FR,
2194 ) where
2195 Left: Model + Clone + 'static,
2196 Right: Model + Clone + 'static,
2197 FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2198 FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2199 {
2200 let left_coll = left_accessor(left);
2202 left_coll.link(right);
2203
2204 let right_coll = right_accessor(right);
2206 right_coll.link(left);
2207
2208 tracing::debug!(
2209 left_model = std::any::type_name::<Left>(),
2210 right_model = std::any::type_name::<Right>(),
2211 "Established bidirectional ManyToMany relationship"
2212 );
2213 }
2214
2215 pub fn unrelate_many_to_many<Left, Right, FL, FR>(
2219 &self,
2220 left: &mut Left,
2221 left_accessor: FL,
2222 right: &mut Right,
2223 right_accessor: FR,
2224 ) where
2225 Left: Model + Clone + 'static,
2226 Right: Model + Clone + 'static,
2227 FL: FnOnce(&mut Left) -> &mut sqlmodel_core::RelatedMany<Right>,
2228 FR: FnOnce(&mut Right) -> &mut sqlmodel_core::RelatedMany<Left>,
2229 {
2230 let left_coll = left_accessor(left);
2232 left_coll.unlink(right);
2233
2234 let right_coll = right_accessor(right);
2236 right_coll.unlink(left);
2237
2238 tracing::debug!(
2239 left_model = std::any::type_name::<Left>(),
2240 right_model = std::any::type_name::<Right>(),
2241 "Removed bidirectional ManyToMany relationship"
2242 );
2243 }
2244
2245 pub fn enable_n1_detection(&mut self, threshold: usize) {
2270 self.n1_tracker = Some(N1QueryTracker::new().with_threshold(threshold));
2271 }
2272
2273 pub fn disable_n1_detection(&mut self) {
2275 self.n1_tracker = None;
2276 }
2277
2278 #[must_use]
2280 pub fn n1_detection_enabled(&self) -> bool {
2281 self.n1_tracker.is_some()
2282 }
2283
2284 pub fn n1_tracker_mut(&mut self) -> Option<&mut N1QueryTracker> {
2286 self.n1_tracker.as_mut()
2287 }
2288
2289 #[must_use]
2291 pub fn n1_stats(&self) -> Option<N1Stats> {
2292 self.n1_tracker.as_ref().map(|t| t.stats())
2293 }
2294
2295 pub fn reset_n1_tracking(&mut self) {
2297 if let Some(tracker) = &mut self.n1_tracker {
2298 tracker.reset();
2299 }
2300 }
2301
2302 #[track_caller]
2306 pub fn record_lazy_load(&mut self, parent_type: &'static str, relationship: &'static str) {
2307 if let Some(tracker) = &mut self.n1_tracker {
2308 tracker.record_load(parent_type, relationship);
2309 }
2310 }
2311
2312 #[tracing::instrument(level = "debug", skip(self, cx, model), fields(table = M::TABLE_NAME))]
2356 pub async fn merge<
2357 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2358 >(
2359 &mut self,
2360 cx: &Cx,
2361 model: M,
2362 load: bool,
2363 ) -> Outcome<M, Error> {
2364 let pk_values = model.primary_key_value();
2365 let key = ObjectKey::from_model(&model);
2366
2367 tracing::debug!(
2368 pk = ?pk_values,
2369 load = load,
2370 in_identity_map = self.identity_map.contains_key(&key),
2371 "Merging object"
2372 );
2373
2374 if let Some(tracked) = self.identity_map.get_mut(&key) {
2376 if tracked.state == ObjectState::Detached {
2378 tracing::debug!("Found detached object, treating as new");
2379 } else {
2380 tracing::debug!(
2381 state = ?tracked.state,
2382 "Found tracked object, updating with merged values"
2383 );
2384
2385 let row_data = model.to_row();
2387 tracked.object = Box::new(model.clone());
2388 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
2389 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
2390 tracked.pk_values.clone_from(&pk_values);
2391
2392 if tracked.state == ObjectState::Persistent && !self.pending_dirty.contains(&key) {
2394 self.pending_dirty.push(key);
2395 }
2396
2397 if let Some(obj) = tracked.object.downcast_ref::<M>() {
2399 return Outcome::Ok(obj.clone());
2400 }
2401 }
2402 }
2403
2404 if load {
2406 let has_valid_pk = pk_values
2408 .iter()
2409 .all(|v| !matches!(v, Value::Null | Value::Default));
2410
2411 if has_valid_pk {
2412 tracing::debug!("Loading from database");
2413
2414 let db_result = self.get_by_pk::<M>(cx, &pk_values).await;
2415 match db_result {
2416 Outcome::Ok(Some(_existing)) => {
2417 if let Some(tracked) = self.identity_map.get_mut(&key) {
2420 let row_data = model.to_row();
2421 tracked.object = Box::new(model.clone());
2422 tracked.column_names = row_data.iter().map(|(name, _)| *name).collect();
2423 tracked.values = row_data.into_iter().map(|(_, v)| v).collect();
2424 if !self.pending_dirty.contains(&key) {
2428 self.pending_dirty.push(key);
2429 }
2430
2431 tracing::debug!("Merged values onto DB object");
2432
2433 if let Some(obj) = tracked.object.downcast_ref::<M>() {
2434 return Outcome::Ok(obj.clone());
2435 }
2436 }
2437 }
2438 Outcome::Ok(None) => {
2439 tracing::debug!("Object not found in database, treating as new");
2440 }
2441 Outcome::Err(e) => return Outcome::Err(e),
2442 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2443 Outcome::Panicked(p) => return Outcome::Panicked(p),
2444 }
2445 }
2446 }
2447
2448 tracing::debug!("Adding as new object");
2450 self.add(&model);
2451
2452 Outcome::Ok(model)
2453 }
2454
2455 pub async fn merge_without_load<
2467 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2468 >(
2469 &mut self,
2470 cx: &Cx,
2471 model: M,
2472 ) -> Outcome<M, Error> {
2473 self.merge(cx, model, false).await
2474 }
2475
2476 pub fn pending_new_count(&self) -> usize {
2482 self.pending_new.len()
2483 }
2484
2485 pub fn pending_delete_count(&self) -> usize {
2487 self.pending_delete.len()
2488 }
2489
2490 pub fn pending_dirty_count(&self) -> usize {
2492 self.pending_dirty.len()
2493 }
2494
2495 pub fn tracked_count(&self) -> usize {
2497 self.identity_map.len()
2498 }
2499
2500 pub fn in_transaction(&self) -> bool {
2502 self.in_transaction
2503 }
2504
2505 pub fn debug_state(&self) -> SessionDebugInfo {
2507 SessionDebugInfo {
2508 tracked: self.tracked_count(),
2509 pending_new: self.pending_new_count(),
2510 pending_delete: self.pending_delete_count(),
2511 pending_dirty: self.pending_dirty_count(),
2512 in_transaction: self.in_transaction,
2513 }
2514 }
2515
2516 pub async fn bulk_insert<M: Model + Clone + Send + Sync + 'static>(
2530 &mut self,
2531 cx: &Cx,
2532 models: &[M],
2533 ) -> Outcome<u64, Error> {
2534 self.bulk_insert_with_batch_size(cx, models, 1000).await
2535 }
2536
2537 pub async fn bulk_insert_with_batch_size<M: Model + Clone + Send + Sync + 'static>(
2539 &mut self,
2540 cx: &Cx,
2541 models: &[M],
2542 batch_size: usize,
2543 ) -> Outcome<u64, Error> {
2544 if models.is_empty() {
2545 return Outcome::Ok(0);
2546 }
2547
2548 let batch_size = batch_size.max(1);
2549 let mut total_inserted: u64 = 0;
2550
2551 for chunk in models.chunks(batch_size) {
2552 let builder = sqlmodel_query::InsertManyBuilder::new(chunk);
2553 let (sql, params) = builder.build();
2554
2555 if sql.is_empty() {
2556 continue;
2557 }
2558
2559 match self.connection.execute(cx, &sql, ¶ms).await {
2560 Outcome::Ok(count) => total_inserted += count,
2561 Outcome::Err(e) => return Outcome::Err(e),
2562 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2563 Outcome::Panicked(p) => return Outcome::Panicked(p),
2564 }
2565 }
2566
2567 Outcome::Ok(total_inserted)
2568 }
2569
2570 pub async fn bulk_update<M: Model + Clone + Send + Sync + 'static>(
2578 &mut self,
2579 cx: &Cx,
2580 models: &[M],
2581 ) -> Outcome<u64, Error> {
2582 if models.is_empty() {
2583 return Outcome::Ok(0);
2584 }
2585
2586 let mut total_updated: u64 = 0;
2587
2588 for model in models {
2589 let builder = sqlmodel_query::UpdateBuilder::new(model);
2590 let (sql, params) = builder.build();
2591
2592 if sql.is_empty() {
2593 continue;
2594 }
2595
2596 match self.connection.execute(cx, &sql, ¶ms).await {
2597 Outcome::Ok(count) => total_updated += count,
2598 Outcome::Err(e) => return Outcome::Err(e),
2599 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
2600 Outcome::Panicked(p) => return Outcome::Panicked(p),
2601 }
2602 }
2603
2604 Outcome::Ok(total_updated)
2605 }
2606}
2607
2608impl<C, M> LazyLoader<M> for Session<C>
2609where
2610 C: Connection,
2611 M: Model + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
2612{
2613 fn get(
2614 &mut self,
2615 cx: &Cx,
2616 pk: Value,
2617 ) -> impl Future<Output = Outcome<Option<M>, Error>> + Send {
2618 Session::get(self, cx, pk)
2619 }
2620}
2621
/// Snapshot of a session's bookkeeping counters, as produced by `Session::debug_state`.
///
/// `PartialEq`/`Eq` allow comparing whole snapshots in assertions; `Default` is the
/// all-zero, no-transaction state.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionDebugInfo {
    /// Number of objects in the identity map.
    pub tracked: usize,
    /// Number of objects queued for insertion.
    pub pending_new: usize,
    /// Number of objects queued for deletion.
    pub pending_delete: usize,
    /// Number of objects queued for an update.
    pub pending_dirty: usize,
    /// Whether the session considered a transaction open when the snapshot was taken.
    pub in_transaction: bool,
}
2636
2637#[cfg(test)]
2642#[allow(clippy::manual_async_fn)] mod tests {
2644 use super::*;
2645 use asupersync::runtime::RuntimeBuilder;
2646 use sqlmodel_core::Row;
2647 use std::sync::{Arc, Mutex};
2648
2649 #[test]
2650 fn test_session_config_defaults() {
2651 let config = SessionConfig::default();
2652 assert!(config.auto_begin);
2653 assert!(!config.auto_flush);
2654 assert!(config.expire_on_commit);
2655 }
2656
2657 #[test]
2658 fn test_object_key_hash_consistency() {
2659 let values1 = vec![Value::BigInt(42)];
2660 let values2 = vec![Value::BigInt(42)];
2661 let hash1 = hash_values(&values1);
2662 let hash2 = hash_values(&values2);
2663 assert_eq!(hash1, hash2);
2664 }
2665
2666 #[test]
2667 fn test_object_key_hash_different_values() {
2668 let values1 = vec![Value::BigInt(42)];
2669 let values2 = vec![Value::BigInt(43)];
2670 let hash1 = hash_values(&values1);
2671 let hash2 = hash_values(&values2);
2672 assert_ne!(hash1, hash2);
2673 }
2674
2675 #[test]
2676 fn test_object_key_hash_different_types() {
2677 let values1 = vec![Value::BigInt(42)];
2678 let values2 = vec![Value::Text("42".to_string())];
2679 let hash1 = hash_values(&values1);
2680 let hash2 = hash_values(&values2);
2681 assert_ne!(hash1, hash2);
2682 }
2683
2684 #[test]
2685 fn test_session_debug_info() {
2686 let info = SessionDebugInfo {
2687 tracked: 5,
2688 pending_new: 2,
2689 pending_delete: 1,
2690 pending_dirty: 0,
2691 in_transaction: true,
2692 };
2693 assert_eq!(info.tracked, 5);
2694 assert_eq!(info.pending_new, 2);
2695 assert!(info.in_transaction);
2696 }
2697
2698 fn unwrap_outcome<T: std::fmt::Debug>(outcome: Outcome<T, Error>) -> T {
2699 match outcome {
2700 Outcome::Ok(v) => v,
2701 other => std::panic::panic_any(format!("unexpected outcome: {other:?}")),
2702 }
2703 }
2704
2705 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2706 struct Team {
2707 id: Option<i64>,
2708 name: String,
2709 }
2710
2711 impl Model for Team {
2712 const TABLE_NAME: &'static str = "teams";
2713 const PRIMARY_KEY: &'static [&'static str] = &["id"];
2714
2715 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
2716 &[]
2717 }
2718
2719 fn to_row(&self) -> Vec<(&'static str, Value)> {
2720 vec![
2721 ("id", self.id.map_or(Value::Null, Value::BigInt)),
2722 ("name", Value::Text(self.name.clone())),
2723 ]
2724 }
2725
2726 fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
2727 let id: i64 = row.get_named("id")?;
2728 let name: String = row.get_named("name")?;
2729 Ok(Self { id: Some(id), name })
2730 }
2731
2732 fn primary_key_value(&self) -> Vec<Value> {
2733 self.id
2734 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
2735 }
2736
2737 fn is_new(&self) -> bool {
2738 self.id.is_none()
2739 }
2740 }
2741
2742 #[derive(Debug, Clone, Serialize, Deserialize)]
2743 struct Hero {
2744 id: Option<i64>,
2745 team: Lazy<Team>,
2746 }
2747
2748 impl Model for Hero {
2749 const TABLE_NAME: &'static str = "heroes";
2750 const PRIMARY_KEY: &'static [&'static str] = &["id"];
2751
2752 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
2753 &[]
2754 }
2755
2756 fn to_row(&self) -> Vec<(&'static str, Value)> {
2757 vec![]
2758 }
2759
2760 fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
2761 Ok(Self {
2762 id: None,
2763 team: Lazy::empty(),
2764 })
2765 }
2766
2767 fn primary_key_value(&self) -> Vec<Value> {
2768 self.id
2769 .map_or_else(|| vec![Value::Null], |id| vec![Value::BigInt(id)])
2770 }
2771
2772 fn is_new(&self) -> bool {
2773 self.id.is_none()
2774 }
2775 }
2776
2777 #[derive(Debug, Default)]
2778 struct MockState {
2779 query_calls: usize,
2780 }
2781
2782 #[derive(Debug, Clone)]
2783 struct MockConnection {
2784 state: Arc<Mutex<MockState>>,
2785 }
2786
2787 impl sqlmodel_core::Connection for MockConnection {
2788 type Tx<'conn>
2789 = MockTransaction
2790 where
2791 Self: 'conn;
2792
2793 fn query(
2794 &self,
2795 _cx: &Cx,
2796 _sql: &str,
2797 params: &[Value],
2798 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2799 let params = params.to_vec();
2800 let state = Arc::clone(&self.state);
2801 async move {
2802 state.lock().expect("lock poisoned").query_calls += 1;
2803
2804 let mut rows = Vec::new();
2805 for v in params {
2806 match v {
2807 Value::BigInt(1) => rows.push(Row::new(
2808 vec!["id".into(), "name".into()],
2809 vec![Value::BigInt(1), Value::Text("Avengers".into())],
2810 )),
2811 Value::BigInt(2) => rows.push(Row::new(
2812 vec!["id".into(), "name".into()],
2813 vec![Value::BigInt(2), Value::Text("X-Men".into())],
2814 )),
2815 _ => {}
2816 }
2817 }
2818
2819 Outcome::Ok(rows)
2820 }
2821 }
2822
2823 fn query_one(
2824 &self,
2825 _cx: &Cx,
2826 _sql: &str,
2827 _params: &[Value],
2828 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2829 async { Outcome::Ok(None) }
2830 }
2831
2832 fn execute(
2833 &self,
2834 _cx: &Cx,
2835 _sql: &str,
2836 _params: &[Value],
2837 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2838 async { Outcome::Ok(0) }
2839 }
2840
2841 fn insert(
2842 &self,
2843 _cx: &Cx,
2844 _sql: &str,
2845 _params: &[Value],
2846 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
2847 async { Outcome::Ok(0) }
2848 }
2849
2850 fn batch(
2851 &self,
2852 _cx: &Cx,
2853 _statements: &[(String, Vec<Value>)],
2854 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
2855 async { Outcome::Ok(vec![]) }
2856 }
2857
2858 fn begin(&self, _cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2859 async { Outcome::Ok(MockTransaction) }
2860 }
2861
2862 fn begin_with(
2863 &self,
2864 _cx: &Cx,
2865 _isolation: sqlmodel_core::connection::IsolationLevel,
2866 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2867 async { Outcome::Ok(MockTransaction) }
2868 }
2869
2870 fn prepare(
2871 &self,
2872 _cx: &Cx,
2873 _sql: &str,
2874 ) -> impl Future<Output = Outcome<sqlmodel_core::connection::PreparedStatement, Error>> + Send
2875 {
2876 async {
2877 Outcome::Ok(sqlmodel_core::connection::PreparedStatement::new(
2878 0,
2879 String::new(),
2880 0,
2881 ))
2882 }
2883 }
2884
2885 fn query_prepared(
2886 &self,
2887 _cx: &Cx,
2888 _stmt: &sqlmodel_core::connection::PreparedStatement,
2889 _params: &[Value],
2890 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2891 async { Outcome::Ok(vec![]) }
2892 }
2893
2894 fn execute_prepared(
2895 &self,
2896 _cx: &Cx,
2897 _stmt: &sqlmodel_core::connection::PreparedStatement,
2898 _params: &[Value],
2899 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2900 async { Outcome::Ok(0) }
2901 }
2902
2903 fn ping(&self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2904 async { Outcome::Ok(()) }
2905 }
2906
2907 fn close(self, _cx: &Cx) -> impl Future<Output = sqlmodel_core::Result<()>> + Send {
2908 async { Ok(()) }
2909 }
2910 }
2911
2912 struct MockTransaction;
2913
2914 impl sqlmodel_core::connection::TransactionOps for MockTransaction {
2915 fn query(
2916 &self,
2917 _cx: &Cx,
2918 _sql: &str,
2919 _params: &[Value],
2920 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2921 async { Outcome::Ok(vec![]) }
2922 }
2923
2924 fn query_one(
2925 &self,
2926 _cx: &Cx,
2927 _sql: &str,
2928 _params: &[Value],
2929 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2930 async { Outcome::Ok(None) }
2931 }
2932
2933 fn execute(
2934 &self,
2935 _cx: &Cx,
2936 _sql: &str,
2937 _params: &[Value],
2938 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2939 async { Outcome::Ok(0) }
2940 }
2941
2942 fn savepoint(
2943 &self,
2944 _cx: &Cx,
2945 _name: &str,
2946 ) -> impl Future<Output = Outcome<(), Error>> + Send {
2947 async { Outcome::Ok(()) }
2948 }
2949
2950 fn rollback_to(
2951 &self,
2952 _cx: &Cx,
2953 _name: &str,
2954 ) -> impl Future<Output = Outcome<(), Error>> + Send {
2955 async { Outcome::Ok(()) }
2956 }
2957
2958 fn release(
2959 &self,
2960 _cx: &Cx,
2961 _name: &str,
2962 ) -> impl Future<Output = Outcome<(), Error>> + Send {
2963 async { Outcome::Ok(()) }
2964 }
2965
2966 fn commit(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2967 async { Outcome::Ok(()) }
2968 }
2969
2970 fn rollback(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2971 async { Outcome::Ok(()) }
2972 }
2973 }
2974
2975 #[test]
2976 fn test_load_many_single_query_and_populates_lazy() {
2977 let rt = RuntimeBuilder::current_thread()
2978 .build()
2979 .expect("create asupersync runtime");
2980 let cx = Cx::for_testing();
2981
2982 let state = Arc::new(Mutex::new(MockState::default()));
2983 let conn = MockConnection {
2984 state: Arc::clone(&state),
2985 };
2986 let mut session = Session::new(conn);
2987
2988 let heroes = vec![
2989 Hero {
2990 id: Some(1),
2991 team: Lazy::from_fk(1_i64),
2992 },
2993 Hero {
2994 id: Some(2),
2995 team: Lazy::from_fk(2_i64),
2996 },
2997 Hero {
2998 id: Some(3),
2999 team: Lazy::from_fk(1_i64),
3000 },
3001 Hero {
3002 id: Some(4),
3003 team: Lazy::empty(),
3004 },
3005 Hero {
3006 id: Some(5),
3007 team: Lazy::from_fk(999_i64),
3008 },
3009 ];
3010
3011 rt.block_on(async {
3012 let loaded = unwrap_outcome(
3013 session
3014 .load_many::<Hero, Team, _>(&cx, &heroes, |h| &h.team)
3015 .await,
3016 );
3017 assert_eq!(loaded, 3);
3018
3019 assert!(heroes[0].team.is_loaded());
3021 assert_eq!(heroes[0].team.get().unwrap().name, "Avengers");
3022 assert_eq!(heroes[1].team.get().unwrap().name, "X-Men");
3023 assert_eq!(heroes[2].team.get().unwrap().name, "Avengers");
3024
3025 assert!(heroes[3].team.is_loaded());
3027 assert!(heroes[3].team.get().is_none());
3028
3029 assert!(heroes[4].team.is_loaded());
3031 assert!(heroes[4].team.get().is_none());
3032
3033 let team1 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await);
3035 assert_eq!(
3036 team1,
3037 Some(Team {
3038 id: Some(1),
3039 name: "Avengers".to_string()
3040 })
3041 );
3042 });
3043
3044 assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3045 }
3046
3047 #[test]
3048 fn test_add_all_with_vec() {
3049 let state = Arc::new(Mutex::new(MockState::default()));
3050 let conn = MockConnection {
3051 state: Arc::clone(&state),
3052 };
3053 let mut session = Session::new(conn);
3054
3055 let teams = vec![
3058 Team {
3059 id: Some(100),
3060 name: "Team A".to_string(),
3061 },
3062 Team {
3063 id: Some(101),
3064 name: "Team B".to_string(),
3065 },
3066 Team {
3067 id: Some(102),
3068 name: "Team C".to_string(),
3069 },
3070 ];
3071
3072 session.add_all(&teams);
3073
3074 let info = session.debug_state();
3075 assert_eq!(info.pending_new, 3);
3076 assert_eq!(info.tracked, 3);
3077 }
3078
3079 #[test]
3080 fn test_add_all_with_empty_collection() {
3081 let state = Arc::new(Mutex::new(MockState::default()));
3082 let conn = MockConnection {
3083 state: Arc::clone(&state),
3084 };
3085 let mut session = Session::new(conn);
3086
3087 let teams: Vec<Team> = vec![];
3088 session.add_all(&teams);
3089
3090 let info = session.debug_state();
3091 assert_eq!(info.pending_new, 0);
3092 assert_eq!(info.tracked, 0);
3093 }
3094
3095 #[test]
3096 fn test_add_all_with_iterator() {
3097 let state = Arc::new(Mutex::new(MockState::default()));
3098 let conn = MockConnection {
3099 state: Arc::clone(&state),
3100 };
3101 let mut session = Session::new(conn);
3102
3103 let teams = [
3104 Team {
3105 id: Some(200),
3106 name: "Team X".to_string(),
3107 },
3108 Team {
3109 id: Some(201),
3110 name: "Team Y".to_string(),
3111 },
3112 ];
3113
3114 session.add_all(teams.iter());
3116
3117 let info = session.debug_state();
3118 assert_eq!(info.pending_new, 2);
3119 assert_eq!(info.tracked, 2);
3120 }
3121
3122 #[test]
3123 fn test_add_all_with_slice() {
3124 let state = Arc::new(Mutex::new(MockState::default()));
3125 let conn = MockConnection {
3126 state: Arc::clone(&state),
3127 };
3128 let mut session = Session::new(conn);
3129
3130 let teams = [
3131 Team {
3132 id: Some(300),
3133 name: "Team 1".to_string(),
3134 },
3135 Team {
3136 id: Some(301),
3137 name: "Team 2".to_string(),
3138 },
3139 ];
3140
3141 session.add_all(&teams);
3142
3143 let info = session.debug_state();
3144 assert_eq!(info.pending_new, 2);
3145 assert_eq!(info.tracked, 2);
3146 }
3147
3148 #[test]
3151 fn test_merge_new_object_without_load() {
3152 let rt = RuntimeBuilder::current_thread()
3153 .build()
3154 .expect("create asupersync runtime");
3155 let cx = Cx::for_testing();
3156
3157 let state = Arc::new(Mutex::new(MockState::default()));
3158 let conn = MockConnection {
3159 state: Arc::clone(&state),
3160 };
3161 let mut session = Session::new(conn);
3162
3163 rt.block_on(async {
3164 let team = Team {
3166 id: Some(100),
3167 name: "New Team".to_string(),
3168 };
3169
3170 let merged = unwrap_outcome(session.merge(&cx, team.clone(), false).await);
3171
3172 assert_eq!(merged.id, Some(100));
3174 assert_eq!(merged.name, "New Team");
3175
3176 let info = session.debug_state();
3178 assert_eq!(info.pending_new, 1);
3179 assert_eq!(info.tracked, 1);
3180 });
3181
3182 assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3184 }
3185
3186 #[test]
3187 fn test_merge_updates_existing_tracked_object() {
3188 let rt = RuntimeBuilder::current_thread()
3189 .build()
3190 .expect("create asupersync runtime");
3191 let cx = Cx::for_testing();
3192
3193 let state = Arc::new(Mutex::new(MockState::default()));
3194 let conn = MockConnection {
3195 state: Arc::clone(&state),
3196 };
3197 let mut session = Session::new(conn);
3198
3199 rt.block_on(async {
3200 let original = Team {
3202 id: Some(1),
3203 name: "Original".to_string(),
3204 };
3205 session.add(&original);
3206
3207 let updated = Team {
3209 id: Some(1),
3210 name: "Updated".to_string(),
3211 };
3212
3213 let merged = unwrap_outcome(session.merge(&cx, updated, false).await);
3214
3215 assert_eq!(merged.id, Some(1));
3217 assert_eq!(merged.name, "Updated");
3218
3219 let info = session.debug_state();
3221 assert_eq!(info.tracked, 1);
3222 });
3223 }
3224
3225 #[test]
3226 fn test_merge_with_load_queries_database() {
3227 let rt = RuntimeBuilder::current_thread()
3228 .build()
3229 .expect("create asupersync runtime");
3230 let cx = Cx::for_testing();
3231
3232 let state = Arc::new(Mutex::new(MockState::default()));
3233 let conn = MockConnection {
3234 state: Arc::clone(&state),
3235 };
3236 let mut session = Session::new(conn);
3237
3238 rt.block_on(async {
3239 let detached = Team {
3241 id: Some(1),
3242 name: "Detached Update".to_string(),
3243 };
3244
3245 let merged = unwrap_outcome(session.merge(&cx, detached, true).await);
3246
3247 assert_eq!(merged.id, Some(1));
3249 assert_eq!(merged.name, "Detached Update");
3250
3251 let info = session.debug_state();
3253 assert_eq!(info.tracked, 1);
3254 assert_eq!(info.pending_dirty, 1);
3255 });
3256
3257 assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3259 }
3260
3261 #[test]
3262 fn test_merge_with_load_not_found_creates_new() {
3263 let rt = RuntimeBuilder::current_thread()
3264 .build()
3265 .expect("create asupersync runtime");
3266 let cx = Cx::for_testing();
3267
3268 let state = Arc::new(Mutex::new(MockState::default()));
3269 let conn = MockConnection {
3270 state: Arc::clone(&state),
3271 };
3272 let mut session = Session::new(conn);
3273
3274 rt.block_on(async {
3275 let detached = Team {
3277 id: Some(999),
3278 name: "Not In DB".to_string(),
3279 };
3280
3281 let merged = unwrap_outcome(session.merge(&cx, detached, true).await);
3282
3283 assert_eq!(merged.id, Some(999));
3285 assert_eq!(merged.name, "Not In DB");
3286
3287 let info = session.debug_state();
3289 assert_eq!(info.pending_new, 1);
3290 assert_eq!(info.tracked, 1);
3291 });
3292
3293 assert_eq!(state.lock().expect("lock poisoned").query_calls, 1);
3295 }
3296
3297 #[test]
3298 fn test_merge_without_load_convenience() {
3299 let rt = RuntimeBuilder::current_thread()
3300 .build()
3301 .expect("create asupersync runtime");
3302 let cx = Cx::for_testing();
3303
3304 let state = Arc::new(Mutex::new(MockState::default()));
3305 let conn = MockConnection {
3306 state: Arc::clone(&state),
3307 };
3308 let mut session = Session::new(conn);
3309
3310 rt.block_on(async {
3311 let team = Team {
3312 id: Some(42),
3313 name: "Test".to_string(),
3314 };
3315
3316 let merged = unwrap_outcome(session.merge_without_load(&cx, team).await);
3318
3319 assert_eq!(merged.id, Some(42));
3320 assert_eq!(merged.name, "Test");
3321
3322 let info = session.debug_state();
3323 assert_eq!(info.pending_new, 1);
3324 });
3325
3326 assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3328 }
3329
3330 #[test]
3331 fn test_merge_null_pk_treated_as_new() {
3332 let rt = RuntimeBuilder::current_thread()
3333 .build()
3334 .expect("create asupersync runtime");
3335 let cx = Cx::for_testing();
3336
3337 let state = Arc::new(Mutex::new(MockState::default()));
3338 let conn = MockConnection {
3339 state: Arc::clone(&state),
3340 };
3341 let mut session = Session::new(conn);
3342
3343 rt.block_on(async {
3344 let new_team = Team {
3346 id: None,
3347 name: "Brand New".to_string(),
3348 };
3349
3350 let merged = unwrap_outcome(session.merge(&cx, new_team, true).await);
3351
3352 assert_eq!(merged.id, None);
3354 assert_eq!(merged.name, "Brand New");
3355
3356 let info = session.debug_state();
3358 assert_eq!(info.pending_new, 1);
3359 });
3360
3361 assert_eq!(state.lock().expect("lock poisoned").query_calls, 0);
3363 }
3364
3365 #[test]
3368 fn test_is_modified_new_object_returns_true() {
3369 let state = Arc::new(Mutex::new(MockState::default()));
3370 let conn = MockConnection {
3371 state: Arc::clone(&state),
3372 };
3373 let mut session = Session::new(conn);
3374
3375 let team = Team {
3376 id: Some(100),
3377 name: "New Team".to_string(),
3378 };
3379
3380 session.add(&team);
3382 assert!(session.is_modified(&team));
3383 }
3384
3385 #[test]
3386 fn test_is_modified_untracked_returns_false() {
3387 let state = Arc::new(Mutex::new(MockState::default()));
3388 let conn = MockConnection {
3389 state: Arc::clone(&state),
3390 };
3391 let session = Session::<MockConnection>::new(conn);
3392
3393 let team = Team {
3394 id: Some(100),
3395 name: "Not Tracked".to_string(),
3396 };
3397
3398 assert!(!session.is_modified(&team));
3400 }
3401
3402 #[test]
3403 fn test_is_modified_after_load_returns_false() {
3404 let rt = RuntimeBuilder::current_thread()
3405 .build()
3406 .expect("create asupersync runtime");
3407 let cx = Cx::for_testing();
3408
3409 let state = Arc::new(Mutex::new(MockState::default()));
3410 let conn = MockConnection {
3411 state: Arc::clone(&state),
3412 };
3413 let mut session = Session::new(conn);
3414
3415 rt.block_on(async {
3416 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3418
3419 assert!(!session.is_modified(&team));
3421 });
3422 }
3423
3424 #[test]
3425 fn test_is_modified_after_mark_dirty_returns_true() {
3426 let rt = RuntimeBuilder::current_thread()
3427 .build()
3428 .expect("create asupersync runtime");
3429 let cx = Cx::for_testing();
3430
3431 let state = Arc::new(Mutex::new(MockState::default()));
3432 let conn = MockConnection {
3433 state: Arc::clone(&state),
3434 };
3435 let mut session = Session::new(conn);
3436
3437 rt.block_on(async {
3438 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3440 assert!(!session.is_modified(&team));
3441
3442 let mut modified_team = team.clone();
3444 modified_team.name = "Modified Name".to_string();
3445 session.mark_dirty(&modified_team);
3446
3447 assert!(session.is_modified(&modified_team));
3449 });
3450 }
3451
3452 #[test]
3453 fn test_is_modified_deleted_returns_true() {
3454 let rt = RuntimeBuilder::current_thread()
3455 .build()
3456 .expect("create asupersync runtime");
3457 let cx = Cx::for_testing();
3458
3459 let state = Arc::new(Mutex::new(MockState::default()));
3460 let conn = MockConnection {
3461 state: Arc::clone(&state),
3462 };
3463 let mut session = Session::new(conn);
3464
3465 rt.block_on(async {
3466 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3468 assert!(!session.is_modified(&team));
3469
3470 session.delete(&team);
3472
3473 assert!(session.is_modified(&team));
3475 });
3476 }
3477
3478 #[test]
3479 fn test_is_modified_detached_returns_false() {
3480 let rt = RuntimeBuilder::current_thread()
3481 .build()
3482 .expect("create asupersync runtime");
3483 let cx = Cx::for_testing();
3484
3485 let state = Arc::new(Mutex::new(MockState::default()));
3486 let conn = MockConnection {
3487 state: Arc::clone(&state),
3488 };
3489 let mut session = Session::new(conn);
3490
3491 rt.block_on(async {
3492 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3494
3495 session.expunge(&team);
3497
3498 assert!(!session.is_modified(&team));
3500 });
3501 }
3502
3503 #[test]
3504 fn test_object_state_returns_correct_state() {
3505 let rt = RuntimeBuilder::current_thread()
3506 .build()
3507 .expect("create asupersync runtime");
3508 let cx = Cx::for_testing();
3509
3510 let state = Arc::new(Mutex::new(MockState::default()));
3511 let conn = MockConnection {
3512 state: Arc::clone(&state),
3513 };
3514 let mut session = Session::new(conn);
3515
3516 let untracked = Team {
3518 id: Some(999),
3519 name: "Untracked".to_string(),
3520 };
3521 assert_eq!(session.object_state(&untracked), None);
3522
3523 let new_team = Team {
3525 id: Some(100),
3526 name: "New".to_string(),
3527 };
3528 session.add(&new_team);
3529 assert_eq!(session.object_state(&new_team), Some(ObjectState::New));
3530
3531 rt.block_on(async {
3532 let persistent = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3534 assert_eq!(
3535 session.object_state(&persistent),
3536 Some(ObjectState::Persistent)
3537 );
3538
3539 session.delete(&persistent);
3541 assert_eq!(
3542 session.object_state(&persistent),
3543 Some(ObjectState::Deleted)
3544 );
3545 });
3546 }
3547
3548 #[test]
3549 fn test_modified_attributes_returns_changed_columns() {
3550 let rt = RuntimeBuilder::current_thread()
3551 .build()
3552 .expect("create asupersync runtime");
3553 let cx = Cx::for_testing();
3554
3555 let state = Arc::new(Mutex::new(MockState::default()));
3556 let conn = MockConnection {
3557 state: Arc::clone(&state),
3558 };
3559 let mut session = Session::new(conn);
3560
3561 rt.block_on(async {
3562 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3564
3565 let modified = session.modified_attributes(&team);
3567 assert!(modified.is_empty());
3568
3569 let mut modified_team = team.clone();
3571 modified_team.name = "Changed Name".to_string();
3572 session.mark_dirty(&modified_team);
3573
3574 let modified = session.modified_attributes(&modified_team);
3576 assert!(modified.contains(&"name"));
3577 });
3578 }
3579
3580 #[test]
3581 fn test_modified_attributes_untracked_returns_empty() {
3582 let state = Arc::new(Mutex::new(MockState::default()));
3583 let conn = MockConnection {
3584 state: Arc::clone(&state),
3585 };
3586 let session = Session::<MockConnection>::new(conn);
3587
3588 let team = Team {
3589 id: Some(100),
3590 name: "Not Tracked".to_string(),
3591 };
3592
3593 let modified = session.modified_attributes(&team);
3594 assert!(modified.is_empty());
3595 }
3596
3597 #[test]
3598 fn test_modified_attributes_new_returns_empty() {
3599 let state = Arc::new(Mutex::new(MockState::default()));
3600 let conn = MockConnection {
3601 state: Arc::clone(&state),
3602 };
3603 let mut session = Session::new(conn);
3604
3605 let team = Team {
3606 id: Some(100),
3607 name: "New".to_string(),
3608 };
3609 session.add(&team);
3610
3611 let modified = session.modified_attributes(&team);
3613 assert!(modified.is_empty());
3614 }
3615
3616 #[test]
3619 fn test_expire_marks_object_as_expired() {
3620 let rt = RuntimeBuilder::current_thread()
3621 .build()
3622 .expect("create asupersync runtime");
3623 let cx = Cx::for_testing();
3624
3625 let state = Arc::new(Mutex::new(MockState::default()));
3626 let conn = MockConnection {
3627 state: Arc::clone(&state),
3628 };
3629 let mut session = Session::new(conn);
3630
3631 rt.block_on(async {
3632 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await);
3634 assert!(team.is_some());
3635 let team = team.unwrap();
3636
3637 assert!(!session.is_expired(&team));
3639 assert_eq!(session.object_state(&team), Some(ObjectState::Persistent));
3640
3641 session.expire(&team, None);
3643
3644 assert!(session.is_expired(&team));
3646 assert_eq!(session.object_state(&team), Some(ObjectState::Expired));
3647 });
3648 }
3649
3650 #[test]
3651 fn test_expire_specific_attributes() {
3652 let rt = RuntimeBuilder::current_thread()
3653 .build()
3654 .expect("create asupersync runtime");
3655 let cx = Cx::for_testing();
3656
3657 let state = Arc::new(Mutex::new(MockState::default()));
3658 let conn = MockConnection {
3659 state: Arc::clone(&state),
3660 };
3661 let mut session = Session::new(conn);
3662
3663 rt.block_on(async {
3664 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3666
3667 session.expire(&team, Some(&["name"]));
3669
3670 assert!(session.is_expired(&team));
3672
3673 let expired = session.expired_attributes(&team);
3675 assert!(expired.is_some());
3676 let expired_set = expired.unwrap();
3677 assert!(expired_set.is_some());
3678 assert!(expired_set.unwrap().contains("name"));
3679 });
3680 }
3681
3682 #[test]
3683 fn test_expire_all_marks_all_objects_expired() {
3684 let rt = RuntimeBuilder::current_thread()
3685 .build()
3686 .expect("create asupersync runtime");
3687 let cx = Cx::for_testing();
3688
3689 let state = Arc::new(Mutex::new(MockState::default()));
3690 let conn = MockConnection {
3691 state: Arc::clone(&state),
3692 };
3693 let mut session = Session::new(conn);
3694
3695 rt.block_on(async {
3696 let team1 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3698 let team2 = unwrap_outcome(session.get::<Team>(&cx, 2_i64).await).unwrap();
3699
3700 assert!(!session.is_expired(&team1));
3702 assert!(!session.is_expired(&team2));
3703
3704 session.expire_all();
3706
3707 assert!(session.is_expired(&team1));
3709 assert!(session.is_expired(&team2));
3710 });
3711 }
3712
3713 #[test]
3714 fn test_expire_does_not_affect_new_objects() {
3715 let state = Arc::new(Mutex::new(MockState::default()));
3716 let conn = MockConnection {
3717 state: Arc::clone(&state),
3718 };
3719 let mut session = Session::new(conn);
3720
3721 let team = Team {
3723 id: Some(100),
3724 name: "New Team".to_string(),
3725 };
3726 session.add(&team);
3727
3728 session.expire(&team, None);
3730
3731 assert_eq!(session.object_state(&team), Some(ObjectState::New));
3733 assert!(!session.is_expired(&team));
3734 }
3735
3736 #[test]
3737 fn test_expired_object_reloads_on_get() {
3738 let rt = RuntimeBuilder::current_thread()
3739 .build()
3740 .expect("create asupersync runtime");
3741 let cx = Cx::for_testing();
3742
3743 let state = Arc::new(Mutex::new(MockState::default()));
3744 let conn = MockConnection {
3745 state: Arc::clone(&state),
3746 };
3747 let mut session = Session::new(conn);
3748
3749 rt.block_on(async {
3750 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3752 assert_eq!(team.name, "Avengers");
3753
3754 let team2 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3756 assert_eq!(team2.name, "Avengers");
3757
3758 {
3760 let s = state.lock().expect("lock poisoned");
3761 assert_eq!(s.query_calls, 1);
3762 }
3763
3764 session.expire(&team, None);
3766
3767 let team3 = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3769 assert_eq!(team3.name, "Avengers");
3770
3771 {
3773 let s = state.lock().expect("lock poisoned");
3774 assert_eq!(s.query_calls, 2);
3775 }
3776
3777 assert!(!session.is_expired(&team3));
3779 assert_eq!(session.object_state(&team3), Some(ObjectState::Persistent));
3780 });
3781 }
3782
3783 #[test]
3784 fn test_is_expired_returns_false_for_untracked() {
3785 let state = Arc::new(Mutex::new(MockState::default()));
3786 let conn = MockConnection {
3787 state: Arc::clone(&state),
3788 };
3789 let session = Session::<MockConnection>::new(conn);
3790
3791 let team = Team {
3792 id: Some(999),
3793 name: "Not Tracked".to_string(),
3794 };
3795
3796 assert!(!session.is_expired(&team));
3798 }
3799
3800 #[test]
3801 fn test_expired_attributes_returns_none_for_persistent() {
3802 let rt = RuntimeBuilder::current_thread()
3803 .build()
3804 .expect("create asupersync runtime");
3805 let cx = Cx::for_testing();
3806
3807 let state = Arc::new(Mutex::new(MockState::default()));
3808 let conn = MockConnection {
3809 state: Arc::clone(&state),
3810 };
3811 let mut session = Session::new(conn);
3812
3813 rt.block_on(async {
3814 let team = unwrap_outcome(session.get::<Team>(&cx, 1_i64).await).unwrap();
3816
3817 let expired = session.expired_attributes(&team);
3819 assert!(expired.is_none());
3820 });
3821 }
3822}