zino_orm/
mutation.rs

//! Generates SQL `SET` expressions.
2use super::{DatabaseDriver, EncodeColumn, Entity, IntoSqlValue, Schema, query::QueryExt};
3use std::marker::PhantomData;
4use zino_core::{
5    JsonValue, Map,
6    datetime::DateTime,
7    extension::{JsonObjectExt, JsonValueExt},
8    model::{Mutation, Query},
9};
10
/// A mutation builder for the model entity.
///
/// Plain column assignments and operator-based updates are collected in
/// separate maps keyed by column name; [`MutationBuilder::build`] merges them
/// into a single [`Mutation`].
///
/// # Examples
/// ```rust,ignore
/// use crate::model::{User, UserColumn};
/// use zino_orm::{MutationBuilder, QueryBuilder, Schema};
///
/// let query = QueryBuilder::new()
///     .primary_key("01936dc6-e48c-7d22-8e69-b29f85682fac")
///     .and_not_in(UserColumn::Status, ["Deleted", "Locked"])
///     .build();
/// let mut mutation = MutationBuilder::<User>::new()
///     .set(UserColumn::Status, "Active")
///     .set_now(UserColumn::UpdatedAt)
///     .inc_one(UserColumn::Version)
///     .build();
/// let ctx = User::update_one(&query, &mut mutation).await?;
/// ```
#[derive(Debug, Clone)]
pub struct MutationBuilder<E: Entity> {
    /// The mutation updates: direct column assignments keyed by column name.
    updates: Map,
    /// `$inc` operations, filled by `inc` and `inc_one`.
    inc_ops: Map,
    /// `$mul` operations, filled by `mul`.
    mul_ops: Map,
    /// `$min` operations, filled by `min`.
    min_ops: Map,
    /// `$max` operations, filled by `max`.
    max_ops: Map,
    /// The phantom data binding the builder to the entity type `E`,
    /// so that only `E::Column` values are accepted as column names.
    phantom: PhantomData<E>,
}
44
45impl<E: Entity> MutationBuilder<E> {
46    /// Creates a new instance.
47    #[inline]
48    pub fn new() -> Self {
49        Self {
50            updates: Map::new(),
51            inc_ops: Map::new(),
52            mul_ops: Map::new(),
53            min_ops: Map::new(),
54            max_ops: Map::new(),
55            phantom: PhantomData,
56        }
57    }
58
59    /// Update the values for partial columns.
60    ///
61    /// # Examples
62    /// ```rust,ignore
63    /// use crate::model::{Project, ProjectColumn, Task, TaskColumn};
64    /// use zino_orm::{Aggregation, MutationBuilder, QueryBuilder, Schema};
65    ///
66    /// let project_id = task.project_id();
67    /// let query = QueryBuilder::new()
68    ///     .aggregate(Aggregation::Count(TaskColumn::Id, false), ProjectColumn::NumTasks)
69    ///     .aggregate(Aggregation::Sum(TaskColumn::Manhours), ProjectColumn::TotalManhours)
70    ///     .and_eq(TaskColumn::ProjectId, project_id)
71    ///     .build();
72    /// if let Some(stats_data) = Task::find_one(&query).await? {
73    ///     let query = QueryBuilder::<Project>::new()
74    ///         .primary_key(project_id)
75    ///         .build();
76    ///     let mut mutation = MutationBuilder::<Project>::new()
77    ///         .update_partial(Project::generated_columns(), stats_data)
78    ///         .set_now(ProjectColumn::UpdatedAt)
79    ///         .inc_one(ProjectColumn::Version)
80    ///         .build();
81    ///     Project::update_one(&query, &mut mutation).await?;
82    /// }
83    /// ```
84    #[inline]
85    pub fn update_partial(mut self, cols: &[E::Column], mut data: Map) -> Self {
86        if cfg!(debug_assertions) && cols.is_empty() {
87            tracing::warn!("no columns to be updated");
88        }
89        for col in cols {
90            let field = col.as_ref();
91            if let Some(value) = data.remove(field) {
92                self.updates.upsert(field, value);
93            }
94        }
95        self
96    }
97
98    /// Sets the value of a column.
99    #[inline]
100    pub fn set(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
101        self.updates.upsert(col.as_ref(), value.into_sql_value());
102        self
103    }
104
105    /// Sets the value of a column if the value is not null.
106    #[inline]
107    pub fn set_if_not_null(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
108        let value = value.into_sql_value();
109        if !value.is_null() {
110            self.updates.upsert(col.as_ref(), value);
111        }
112        self
113    }
114
115    /// Sets the value of a column if the value is not empty or null.
116    #[inline]
117    pub fn set_if_nonempty(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
118        let value = value.into_sql_value();
119        if !value.is_ignorable() {
120            self.updates.upsert(col.as_ref(), value);
121        }
122        self
123    }
124
125    /// Sets the value of a column if the value is not none.
126    #[inline]
127    pub fn set_if_some<T: IntoSqlValue>(mut self, col: E::Column, value: Option<T>) -> Self {
128        if let Some(value) = value {
129            self.updates.upsert(col.as_ref(), value.into_sql_value());
130        }
131        self
132    }
133
134    /// Sets the value of a column to null.
135    #[inline]
136    pub fn set_null(mut self, col: E::Column) -> Self {
137        self.updates.upsert(col.as_ref(), JsonValue::Null);
138        self
139    }
140
141    /// Sets the value of a column to the current date time.
142    #[inline]
143    pub fn set_now(mut self, col: E::Column) -> Self {
144        self.updates
145            .upsert(col.as_ref(), DateTime::now().into_sql_value());
146        self
147    }
148
149    /// Increments the value of a column.
150    #[inline]
151    pub fn inc(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
152        self.inc_ops.upsert(col.as_ref(), value.into_sql_value());
153        self
154    }
155
156    /// Increments the value of a column by 1.
157    #[inline]
158    pub fn inc_one(mut self, col: E::Column) -> Self {
159        self.inc_ops.upsert(col.as_ref(), 1);
160        self
161    }
162
163    /// Multiplies the value of a column by a number.
164    #[inline]
165    pub fn mul(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
166        self.mul_ops.upsert(col.as_ref(), value.into_sql_value());
167        self
168    }
169
170    /// Updates the value of a column to a specified value
171    /// if the specified value is less than the current value of the column.
172    #[inline]
173    pub fn min(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
174        self.min_ops.upsert(col.as_ref(), value.into_sql_value());
175        self
176    }
177
178    /// Updates the value of a column to a specified value
179    /// if the specified value is greater than the current value of the column.
180    #[inline]
181    pub fn max(mut self, col: E::Column, value: impl IntoSqlValue) -> Self {
182        self.max_ops.upsert(col.as_ref(), value.into_sql_value());
183        self
184    }
185
186    /// Builds the model mutation.
187    pub fn build(self) -> Mutation {
188        let mut updates = self.updates;
189        let inc_ops = self.inc_ops;
190        let mul_ops = self.mul_ops;
191        let min_ops = self.min_ops;
192        let max_ops = self.max_ops;
193        if !inc_ops.is_empty() {
194            updates.upsert("$inc", inc_ops);
195        }
196        if !mul_ops.is_empty() {
197            updates.upsert("$mul", mul_ops);
198        }
199        if !min_ops.is_empty() {
200            updates.upsert("$min", min_ops);
201        }
202        if !max_ops.is_empty() {
203            updates.upsert("$max", max_ops);
204        }
205        Mutation::new(updates)
206    }
207}
208
209impl<E: Entity> Default for MutationBuilder<E> {
210    #[inline]
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
/// Extension trait for [`Mutation`] that renders its updates as SQL.
///
/// The `DB` type parameter lets an implementation be selected per database
/// driver; this file implements it for `DatabaseDriver`.
pub(super) trait MutationExt<DB> {
    /// Formats the updates to generate SQL `SET` expression.
    ///
    /// `M` is the schema used to resolve which columns are writable.
    fn format_updates<M: Schema>(&self) -> String;
}
221
222impl MutationExt<DatabaseDriver> for Mutation {
223    fn format_updates<M: Schema>(&self) -> String {
224        let updates = self.updates();
225        if updates.is_empty() {
226            return String::new();
227        }
228
229        let fields = self.fields();
230        let permissive = fields.is_empty();
231        let mut mutations = Vec::new();
232        for (key, value) in updates.iter() {
233            match key.as_str() {
234                "$inc" => {
235                    if let Some(update) = value.as_object() {
236                        for (key, value) in update.iter() {
237                            if (permissive || fields.contains(key))
238                                && let Some(col) = M::get_writable_column(key)
239                            {
240                                let key = Query::format_field(key);
241                                let value = col.encode_value(Some(value));
242                                let mutation = format!(r#"{key} = {value} + {key}"#);
243                                mutations.push(mutation);
244                            }
245                        }
246                    }
247                }
248                "$mul" => {
249                    if let Some(update) = value.as_object() {
250                        for (key, value) in update.iter() {
251                            if (permissive || fields.contains(key))
252                                && let Some(col) = M::get_writable_column(key)
253                            {
254                                let key = Query::format_field(key);
255                                let value = col.encode_value(Some(value));
256                                let mutation = format!(r#"{key} = {value} * {key}"#);
257                                mutations.push(mutation);
258                            }
259                        }
260                    }
261                }
262                "$min" => {
263                    if let Some(update) = value.as_object() {
264                        for (key, value) in update.iter() {
265                            if (permissive || fields.contains(key))
266                                && let Some(col) = M::get_writable_column(key)
267                            {
268                                let key = Query::format_field(key);
269                                let value = col.encode_value(Some(value));
270                                let mutation = if cfg!(feature = "orm-sqlite") {
271                                    format!(r#"{key} = MIN({value}, {key})"#)
272                                } else {
273                                    format!(r#"{key} = LEAST({value}, {key})"#)
274                                };
275                                mutations.push(mutation);
276                            }
277                        }
278                    }
279                }
280                "$max" => {
281                    if let Some(update) = value.as_object() {
282                        for (key, value) in update.iter() {
283                            if (permissive || fields.contains(key))
284                                && let Some(col) = M::get_writable_column(key)
285                            {
286                                let key = Query::format_field(key);
287                                let value = col.encode_value(Some(value));
288                                let mutation = if cfg!(feature = "orm-sqlite") {
289                                    format!(r#"{key} = MAX({value}, {key})"#)
290                                } else {
291                                    format!(r#"{key} = GREATEST({value}, {key})"#)
292                                };
293                                mutations.push(mutation);
294                            }
295                        }
296                    }
297                }
298                _ => {
299                    if (permissive || fields.contains(key))
300                        && let Some(col) = M::get_writable_column(key)
301                    {
302                        let key = Query::format_field(key);
303                        let mutation = if let Some(subquery) =
304                            value.as_object().and_then(|m| m.get_str("$subquery"))
305                        {
306                            format!(r#"{key} = {subquery}"#)
307                        } else {
308                            let value = col.encode_value(Some(value));
309                            format!(r#"{key} = {value}"#)
310                        };
311                        mutations.push(mutation);
312                    }
313                }
314            }
315        }
316        mutations.join(", ")
317    }
318}