Skip to main content

sqlmodel_core/
tracked.rs

1//! Tracked model wrapper for Pydantic-compatible `exclude_unset`.
2//!
3//! Rust structs do not retain "field was explicitly provided" metadata by default.
4//! `TrackedModel<T>` stores a `FieldsSet` alongside `T` so dumps can implement
5//! Pydantic's `exclude_unset` semantics precisely.
6
7use std::ops::{Deref, DerefMut};
8
9use crate::fields_set::FieldsSet;
10use crate::validate::{DumpOptions, DumpResult, apply_serialization_aliases};
11use crate::{FieldInfo, Model};
12
/// A model instance with explicit "fields set" tracking.
///
/// Pairs a value of `T` with a `FieldsSet` recording which of its fields were
/// explicitly provided. The `exclude_unset` dump option consults this bitset.
#[derive(Clone, Debug)]
pub struct TrackedModel<T> {
    /// The wrapped model value.
    inner: T,
    /// One bit per entry of `T::fields()`, indexed by field position; a set bit
    /// marks the field as explicitly provided. Values passed to `new` are stored
    /// as-is, without checking the length against `T`.
    fields_set: FieldsSet,
}
19
20impl<T> TrackedModel<T> {
21    /// Wrap an instance and an explicit FieldsSet.
22    #[must_use]
23    pub const fn new(inner: T, fields_set: FieldsSet) -> Self {
24        Self { inner, fields_set }
25    }
26
27    /// Borrow the wrapped instance.
28    #[must_use]
29    pub const fn inner(&self) -> &T {
30        &self.inner
31    }
32
33    /// Mutably borrow the wrapped instance.
34    #[must_use]
35    pub fn inner_mut(&mut self) -> &mut T {
36        &mut self.inner
37    }
38
39    /// Consume and return the wrapped instance.
40    #[must_use]
41    pub fn into_inner(self) -> T {
42        self.inner
43    }
44
45    /// Access the explicit fields-set bitset.
46    #[must_use]
47    pub const fn fields_set(&self) -> &FieldsSet {
48        &self.fields_set
49    }
50}
51
52impl<T: Model> TrackedModel<T> {
53    /// Wrap an instance and mark all model fields as set.
54    #[must_use]
55    pub fn all_fields_set(inner: T) -> Self {
56        let fields_set = FieldsSet::all(T::fields().len());
57        Self { inner, fields_set }
58    }
59
60    /// Wrap an instance, marking only the provided field names as "set".
61    ///
62    /// Field names must match `FieldInfo.name` values (post-alias).
63    #[must_use]
64    pub fn from_explicit_field_names(inner: T, names: &[&str]) -> Self {
65        let mut fields_set = FieldsSet::empty(T::fields().len());
66        for (idx, field) in T::fields().iter().enumerate() {
67            if names.contains(&field.name) {
68                fields_set.set(idx);
69            }
70        }
71        Self { inner, fields_set }
72    }
73
74    fn apply_field_exclusions(
75        map: &mut serde_json::Map<String, serde_json::Value>,
76        fields: &[FieldInfo],
77        fields_set: &FieldsSet,
78        exclude_unset: bool,
79        exclude_computed_fields: bool,
80        exclude_defaults: bool,
81    ) {
82        // Always honor per-field exclude flag (Pydantic Field(exclude=True) semantics).
83        for field in fields {
84            if field.exclude {
85                map.remove(field.name);
86            }
87        }
88
89        if exclude_unset {
90            for (idx, field) in fields.iter().enumerate() {
91                if !fields_set.is_set(idx) {
92                    map.remove(field.name);
93                }
94            }
95        }
96
97        if exclude_computed_fields {
98            for field in fields {
99                if field.computed {
100                    map.remove(field.name);
101                }
102            }
103        }
104
105        if exclude_defaults {
106            for field in fields {
107                if let Some(default_json) = field.default_json {
108                    if let Some(current_value) = map.get(field.name) {
109                        if let Ok(default_value) =
110                            serde_json::from_str::<serde_json::Value>(default_json)
111                        {
112                            if current_value == &default_value {
113                                map.remove(field.name);
114                            }
115                        }
116                    }
117                }
118            }
119        }
120    }
121}
122
123impl<T: Model + serde::Serialize> TrackedModel<T> {
124    /// Model-aware dump with correct `exclude_unset` semantics.
125    ///
126    /// This mirrors `SqlModelDump::sql_model_dump`, but supports `exclude_unset`
127    /// by consulting the stored `FieldsSet`.
128    pub fn sql_model_dump(&self, options: DumpOptions) -> DumpResult {
129        let DumpOptions {
130            include,
131            exclude,
132            by_alias,
133            exclude_unset,
134            exclude_defaults,
135            exclude_none,
136            exclude_computed_fields,
137            mode: _,
138            round_trip: _,
139            indent: _,
140        } = options;
141
142        let mut value = serde_json::to_value(&self.inner)?;
143
144        if let serde_json::Value::Object(ref mut map) = value {
145            Self::apply_field_exclusions(
146                map,
147                T::fields(),
148                &self.fields_set,
149                exclude_unset,
150                exclude_computed_fields,
151                exclude_defaults,
152            );
153        }
154
155        if by_alias {
156            apply_serialization_aliases(&mut value, T::fields());
157        }
158
159        if let serde_json::Value::Object(ref mut map) = value {
160            if let Some(ref include_set) = include {
161                map.retain(|k, _| include_set.contains(k));
162            }
163            if let Some(ref exclude_set) = exclude {
164                map.retain(|k, _| !exclude_set.contains(k));
165            }
166            if exclude_none {
167                map.retain(|_, v| !v.is_null());
168            }
169        }
170
171        Ok(value)
172    }
173
174    pub fn sql_model_dump_json(&self) -> std::result::Result<String, serde_json::Error> {
175        let value = self.sql_model_dump(DumpOptions::default())?;
176        serde_json::to_string(&value)
177    }
178
179    pub fn sql_model_dump_json_pretty(&self) -> std::result::Result<String, serde_json::Error> {
180        let value = self.sql_model_dump(DumpOptions::default())?;
181        serde_json::to_string_pretty(&value)
182    }
183
184    pub fn sql_model_dump_json_with_options(
185        &self,
186        options: DumpOptions,
187    ) -> std::result::Result<String, serde_json::Error> {
188        let DumpOptions { indent, .. } = options.clone();
189        let value = self.sql_model_dump(DumpOptions {
190            indent: None,
191            ..options
192        })?;
193
194        match indent {
195            Some(spaces) => {
196                let indent_bytes = " ".repeat(spaces).into_bytes();
197                let formatter = serde_json::ser::PrettyFormatter::with_indent(&indent_bytes);
198                let mut writer = Vec::new();
199                let mut ser = serde_json::Serializer::with_formatter(&mut writer, formatter);
200                serde::Serialize::serialize(&value, &mut ser)?;
201                String::from_utf8(writer).map_err(|e| {
202                    serde_json::Error::io(std::io::Error::new(
203                        std::io::ErrorKind::InvalidData,
204                        format!("UTF-8 encoding error: {e}"),
205                    ))
206                })
207            }
208            None => serde_json::to_string(&value),
209        }
210    }
211}
212
213impl<T> Deref for TrackedModel<T> {
214    type Target = T;
215    fn deref(&self) -> &Self::Target {
216        &self.inner
217    }
218}
219
220impl<T> DerefMut for TrackedModel<T> {
221    fn deref_mut(&mut self) -> &mut Self::Target {
222        &mut self.inner
223    }
224}