Skip to main content

sea_orm/executor/
update.rs

1use super::ReturningSelector;
2use crate::{
3    ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, Iterable,
4    PrimaryKeyTrait, SelectModel, UpdateMany, UpdateOne, ValidatedUpdateOne, error::*,
5};
6use sea_query::{FromValueTuple, Query, UpdateStatement};
7
8/// Lower-level executor that runs a raw `sea_query` [`UpdateStatement`].
9/// Most code shouldn't need it directly — prefer
10/// [`EntityTrait::update`](crate::EntityTrait::update) /
11/// [`update_many`](crate::EntityTrait::update_many), which return
12/// strongly-typed builders.
13#[derive(Clone, Debug)]
14pub struct Updater {
15    query: UpdateStatement,
16    check_record_exists: bool,
17}
18
/// Outcome of an `UPDATE` executed without a `RETURNING` clause.
///
/// Only the affected-row count is reported; no models are read back.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UpdateResult {
    /// How many rows the database reported as modified.
    pub rows_affected: u64,
}
26
27impl<A> ValidatedUpdateOne<A>
28where
29    A: ActiveModelTrait,
30{
31    /// Execute an UPDATE operation without a RETURNING clause, yielding an
32    /// [`UpdateResult`] instead of the updated model. Returns
33    /// [`DbErr::RecordNotUpdated`] if no row matches.
34    pub async fn exec_without_returning<C>(self, db: &C) -> Result<UpdateResult, DbErr>
35    where
36        C: ConnectionTrait,
37    {
38        Updater::new(self.query)
39            // If nothing is updated, return RecordNotUpdated error
40            .check_record_exists()
41            .exec(db)
42            .await
43    }
44
45    /// Execute an UPDATE operation on an ActiveModel
46    pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
47    where
48        <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
49        C: ConnectionTrait,
50    {
51        Updater::new(self.query)
52            .exec_update_and_return_updated(self.model, db)
53            .await
54    }
55}
56
57impl<A> UpdateOne<A>
58where
59    A: ActiveModelTrait,
60{
61    /// Execute an UPDATE operation without a RETURNING clause, yielding an
62    /// [`UpdateResult`] instead of the updated model. Returns
63    /// [`DbErr::RecordNotUpdated`] if no row matches.
64    pub async fn exec_without_returning<C>(self, db: &C) -> Result<UpdateResult, DbErr>
65    where
66        C: ConnectionTrait,
67    {
68        self.0?.exec_without_returning(db).await
69    }
70
71    /// Execute an UPDATE operation on an ActiveModel
72    pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
73    where
74        <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
75        C: ConnectionTrait,
76    {
77        self.0?.exec(db).await
78    }
79}
80
81impl<'a, E> UpdateMany<E>
82where
83    E: EntityTrait,
84{
85    /// Execute an update operation on multiple ActiveModels
86    pub async fn exec<C>(self, db: &'a C) -> Result<UpdateResult, DbErr>
87    where
88        C: ConnectionTrait,
89    {
90        Updater::new(self.query).exec(db).await
91    }
92
93    /// Execute an update operation and return the updated model (use `RETURNING` syntax if supported)
94    pub async fn exec_with_returning<C>(self, db: &'a C) -> Result<Vec<E::Model>, DbErr>
95    where
96        C: ConnectionTrait,
97    {
98        Updater::new(self.query)
99            .exec_update_with_returning::<E, _>(db)
100            .await
101    }
102}
103
104impl Updater {
105    /// Instantiate an update using an [UpdateStatement]
106    fn new(query: UpdateStatement) -> Self {
107        Self {
108            query,
109            check_record_exists: false,
110        }
111    }
112
113    fn check_record_exists(mut self) -> Self {
114        self.check_record_exists = true;
115        self
116    }
117
118    /// Execute an update operation
119    pub async fn exec<C>(self, db: &C) -> Result<UpdateResult, DbErr>
120    where
121        C: ConnectionTrait,
122    {
123        if self.is_noop() {
124            return Ok(UpdateResult::default());
125        }
126        let result = db.execute(&self.query).await?;
127        if self.check_record_exists && result.rows_affected() == 0 {
128            return Err(DbErr::RecordNotUpdated);
129        }
130        Ok(UpdateResult {
131            rows_affected: result.rows_affected(),
132        })
133    }
134
135    async fn exec_update_and_return_updated<A, C>(
136        mut self,
137        model: A,
138        db: &C,
139    ) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
140    where
141        A: ActiveModelTrait,
142        C: ConnectionTrait,
143    {
144        type Entity<A> = <A as ActiveModelTrait>::Entity;
145        type Model<A> = <Entity<A> as EntityTrait>::Model;
146        type Column<A> = <Entity<A> as EntityTrait>::Column;
147
148        if self.is_noop() {
149            return find_updated_model_by_id(model, db).await;
150        }
151
152        match db.support_returning() {
153            true => {
154                let db_backend = db.get_database_backend();
155                let returning = Query::returning().exprs(
156                    Column::<A>::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
157                );
158                self.query.returning(returning);
159                let found: Option<Model<A>> =
160                    ReturningSelector::<SelectModel<Model<A>>, _>::from_query(self.query)
161                        .one(db)
162                        .await?;
163                // If we got `None` then we are updating a row that does not exist.
164                match found {
165                    Some(model) => Ok(model),
166                    None => Err(DbErr::RecordNotUpdated),
167                }
168            }
169            false => {
170                // If we updating a row that does not exist then an error will be thrown here.
171                self.check_record_exists = true;
172                self.exec(db).await?;
173                find_updated_model_by_id(model, db).await
174            }
175        }
176    }
177
178    async fn exec_update_with_returning<E, C>(mut self, db: &C) -> Result<Vec<E::Model>, DbErr>
179    where
180        E: EntityTrait,
181        C: ConnectionTrait,
182    {
183        if self.is_noop() {
184            return Ok(vec![]);
185        }
186
187        let db_backend = db.get_database_backend();
188        match db.support_returning() {
189            true => {
190                let returning = Query::returning().exprs(
191                    E::Column::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
192                );
193                self.query.returning(returning);
194                let models: Vec<E::Model> =
195                    ReturningSelector::<SelectModel<E::Model>, _>::from_query(self.query)
196                        .all(db)
197                        .await?;
198                Ok(models)
199            }
200            false => Err(DbErr::BackendNotSupported {
201                db: db_backend.as_str(),
202                ctx: "UPDATE RETURNING",
203            }),
204        }
205    }
206
207    fn is_noop(&self) -> bool {
208        self.query.get_values().is_empty()
209    }
210}
211
212async fn find_updated_model_by_id<A, C>(
213    model: A,
214    db: &C,
215) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
216where
217    A: ActiveModelTrait,
218    C: ConnectionTrait,
219{
220    type Entity<A> = <A as ActiveModelTrait>::Entity;
221    type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType;
222
223    let primary_key_value = match model.get_primary_key_value() {
224        Some(val) => ValueType::<A>::from_value_tuple(val),
225        None => return Err(DbErr::UpdateGetPrimaryKey),
226    };
227    let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
228    // If we cannot select the updated row from db by the cached primary key
229    match found {
230        Some(model) => Ok(model),
231        None => Err(DbErr::RecordNotFound(
232            "Failed to find updated item".to_owned(),
233        )),
234    }
235}
236
#[cfg(test)]
mod tests {
    use crate::{
        ColumnTrait, DbBackend, DbErr, EntityTrait, IntoActiveModel, MockDatabase, MockExecResult,
        QueryFilter, Set, Transaction, Update, UpdateResult, tests_cfg::cake,
    };
    use pretty_assertions::assert_eq;
    use sea_query::Expr;

    /// Exercises the "target row does not exist" paths of the update entry
    /// points against a mocked Postgres connection, then asserts the exact SQL
    /// that was issued.
    #[smol_potat::test]
    async fn update_record_not_found_1() -> Result<(), DbErr> {
        use crate::ActiveModelTrait;

        let updated_cake = cake::Model {
            id: 1,
            name: "Cheese Cake".to_owned(),
        };

        // Canned results for the row-returning statements below. Judging by
        // the transaction log asserted at the end, they are handed out in order:
        //   1   -> `UPDATE ... RETURNING` for id 1 (one row back: success)
        //   2-4 -> three `UPDATE ... RETURNING` for id 2 (no rows: not updated)
        //   5-7 -> the `SELECT`s issued after the three no-op updates
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                vec![updated_cake.clone()],
                vec![],
                vec![],
                vec![],
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
            ])
            // Consumed by the filtered `Update::many` statement (reports 0 rows).
            .append_exec_results([MockExecResult {
                last_insert_id: 0,
                rows_affected: 0,
            }])
            .into_connection();

        let model = cake::Model {
            id: 1,
            name: "New York Cheese".to_owned(),
        };

        // Successful update: the RETURNING row is returned as the model.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await?,
            cake::Model {
                id: 1,
                name: "Cheese Cake".to_owned(),
            }
        );

        let model = cake::Model {
            id: 2,
            name: "New York Cheese".to_owned(),
        };

        // The next three calls target id 2. An empty RETURNING result maps to
        // `RecordNotUpdated` for each single-row entry point.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            cake::Entity::update(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            Update::one(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        // `UpdateMany::exec` does not enable the zero-rows check, so matching
        // nothing is reported as `Ok` with a zero count (plain UPDATE, no
        // RETURNING, in the log).
        assert_eq!(
            Update::many(cake::Entity)
                .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned()))
                .filter(cake::Column::Id.eq(2))
                .exec(&db)
                .await,
            Ok(UpdateResult { rows_affected: 0 })
        );

        // The next three calls pass a model with no changed columns. According
        // to the log, no UPDATE is sent for them, and each results in a SELECT
        // by primary key (the no-op path in `Updater`).
        assert_eq!(
            updated_cake.clone().into_active_model().save(&db).await?,
            updated_cake.clone().into_active_model()
        );

        assert_eq!(
            updated_cake.clone().into_active_model().update(&db).await?,
            updated_cake
        );

        assert_eq!(
            cake::Entity::update(updated_cake.clone().into_active_model())
                .exec(&db)
                .await?,
            updated_cake
        );

        // No SET values at all: `Updater::exec` short-circuits to a zero count
        // without sending anything, so this call adds no entry to the log.
        assert_eq!(
            cake::Entity::update_many().exec(&db).await?.rows_affected,
            0
        );

        assert_eq!(
            db.into_transaction_log(),
            [
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 1i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
            ]
        );

        Ok(())
    }

    /// Updating an ActiveModel with no primary key set must fail with
    /// `PrimaryKeyNotSet`. No query or exec results are mocked, so no
    /// statement is expected to be executed. The error presumably comes from
    /// building the `UpdateOne` and is surfaced through its `self.0?`.
    #[smol_potat::test]
    async fn update_error() {
        use crate::{DbBackend, DbErr, MockDatabase};

        let db = MockDatabase::new(DbBackend::MySql).into_connection();

        assert!(matches!(
            Update::one(cake::ActiveModel {
                ..Default::default()
            })
            .exec(&db)
            .await,
            Err(DbErr::PrimaryKeyNotSet { .. })
        ));

        assert!(matches!(
            cake::Entity::update(cake::ActiveModel::default())
                .exec(&db)
                .await,
            Err(DbErr::PrimaryKeyNotSet { .. })
        ));
    }
}