sea_orm/executor/
update.rs

1use crate::{
2    error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel,
3    Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, UpdateMany, UpdateOne,
4};
5use sea_query::{FromValueTuple, Query, UpdateStatement};
6
7/// Defines an update operation
8#[derive(Clone, Debug)]
9pub struct Updater {
10    query: UpdateStatement,
11    check_record_exists: bool,
12}
13
/// Outcome of executing an update statement
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UpdateResult {
    /// Number of rows the update operation affected
    pub rows_affected: u64,
}
20
21impl<A> UpdateOne<A>
22where
23    A: ActiveModelTrait,
24{
25    /// Execute an update operation on an ActiveModel
26    pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
27    where
28        <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
29        C: ConnectionTrait,
30    {
31        Updater::new(self.query)
32            .exec_update_and_return_updated(self.model, db)
33            .await
34    }
35}
36
37impl<'a, E> UpdateMany<E>
38where
39    E: EntityTrait,
40{
41    /// Execute an update operation on multiple ActiveModels
42    pub async fn exec<C>(self, db: &'a C) -> Result<UpdateResult, DbErr>
43    where
44        C: ConnectionTrait,
45    {
46        Updater::new(self.query).exec(db).await
47    }
48
49    /// Execute an update operation and return the updated model (use `RETURNING` syntax if supported)
50    ///
51    /// # Panics
52    ///
53    /// Panics if the database backend does not support `UPDATE RETURNING`.
54    pub async fn exec_with_returning<C>(self, db: &'a C) -> Result<Vec<E::Model>, DbErr>
55    where
56        C: ConnectionTrait,
57    {
58        Updater::new(self.query)
59            .exec_update_with_returning::<E, _>(db)
60            .await
61    }
62}
63
64impl Updater {
65    /// Instantiate an update using an [UpdateStatement]
66    pub fn new(query: UpdateStatement) -> Self {
67        Self {
68            query,
69            check_record_exists: false,
70        }
71    }
72
73    /// Check if a record exists on the ActiveModel to perform the update operation on
74    pub fn check_record_exists(mut self) -> Self {
75        self.check_record_exists = true;
76        self
77    }
78
79    /// Execute an update operation
80    pub async fn exec<C>(self, db: &C) -> Result<UpdateResult, DbErr>
81    where
82        C: ConnectionTrait,
83    {
84        if self.is_noop() {
85            return Ok(UpdateResult::default());
86        }
87        let builder = db.get_database_backend();
88        let statement = builder.build(&self.query);
89        let result = db.execute(statement).await?;
90        if self.check_record_exists && result.rows_affected() == 0 {
91            return Err(DbErr::RecordNotUpdated);
92        }
93        Ok(UpdateResult {
94            rows_affected: result.rows_affected(),
95        })
96    }
97
98    async fn exec_update_and_return_updated<A, C>(
99        mut self,
100        model: A,
101        db: &C,
102    ) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
103    where
104        A: ActiveModelTrait,
105        C: ConnectionTrait,
106    {
107        type Entity<A> = <A as ActiveModelTrait>::Entity;
108        type Model<A> = <Entity<A> as EntityTrait>::Model;
109        type Column<A> = <Entity<A> as EntityTrait>::Column;
110
111        if self.is_noop() {
112            return find_updated_model_by_id(model, db).await;
113        }
114
115        match db.support_returning() {
116            true => {
117                let db_backend = db.get_database_backend();
118                let returning = Query::returning().exprs(
119                    Column::<A>::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
120                );
121                self.query.returning(returning);
122                let found: Option<Model<A>> = SelectorRaw::<SelectModel<Model<A>>>::from_statement(
123                    db_backend.build(&self.query),
124                )
125                .one(db)
126                .await?;
127                // If we got `None` then we are updating a row that does not exist.
128                match found {
129                    Some(model) => Ok(model),
130                    None => Err(DbErr::RecordNotUpdated),
131                }
132            }
133            false => {
134                // If we updating a row that does not exist then an error will be thrown here.
135                self.check_record_exists().exec(db).await?;
136                find_updated_model_by_id(model, db).await
137            }
138        }
139    }
140
141    async fn exec_update_with_returning<E, C>(mut self, db: &C) -> Result<Vec<E::Model>, DbErr>
142    where
143        E: EntityTrait,
144        C: ConnectionTrait,
145    {
146        if self.is_noop() {
147            return Ok(vec![]);
148        }
149
150        match db.support_returning() {
151            true => {
152                let db_backend = db.get_database_backend();
153                let returning = Query::returning().exprs(
154                    E::Column::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
155                );
156                self.query.returning(returning);
157                let models: Vec<E::Model> = SelectorRaw::<SelectModel<E::Model>>::from_statement(
158                    db_backend.build(&self.query),
159                )
160                .all(db)
161                .await?;
162                Ok(models)
163            }
164            false => unimplemented!("Database backend doesn't support RETURNING"),
165        }
166    }
167
168    fn is_noop(&self) -> bool {
169        self.query.get_values().is_empty()
170    }
171}
172
173async fn find_updated_model_by_id<A, C>(
174    model: A,
175    db: &C,
176) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
177where
178    A: ActiveModelTrait,
179    C: ConnectionTrait,
180{
181    type Entity<A> = <A as ActiveModelTrait>::Entity;
182    type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType;
183
184    let primary_key_value = match model.get_primary_key_value() {
185        Some(val) => ValueType::<A>::from_value_tuple(val),
186        None => return Err(DbErr::UpdateGetPrimaryKey),
187    };
188    let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
189    // If we cannot select the updated row from db by the cached primary key
190    match found {
191        Some(model) => Ok(model),
192        None => Err(DbErr::RecordNotFound(
193            "Failed to find updated item".to_owned(),
194        )),
195    }
196}
197
#[cfg(test)]
mod tests {
    use crate::{entity::prelude::*, tests_cfg::*, *};
    use pretty_assertions::assert_eq;
    use sea_query::Expr;

    // Drives every update entry point against a mock Postgres connection:
    // one successful update, three updates of a missing row, a bulk update
    // affecting nothing, and four no-op updates, then checks the exact
    // sequence of statements that reached the connection.
    #[smol_potat::test]
    async fn update_record_not_found_1() -> Result<(), DbErr> {
        let cheese_cake = cake::Model {
            id: 1,
            name: "Cheese Cake".to_owned(),
        };

        // Query results are consumed in order, one per statement that fetches rows.
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                vec![cheese_cake.clone()],
                vec![],
                vec![],
                vec![],
                vec![cheese_cake.clone()],
                vec![cheese_cake.clone()],
                vec![cheese_cake.clone()],
            ])
            .append_exec_results([MockExecResult {
                last_insert_id: 0,
                rows_affected: 0,
            }])
            .into_connection();

        // Active model for `row` whose name is set to "Cheese Cake".
        let renamed = |row: &cake::Model| cake::ActiveModel {
            name: Set("Cheese Cake".to_owned()),
            ..row.clone().into_active_model()
        };

        let existing = cake::Model {
            id: 1,
            name: "New York Cheese".to_owned(),
        };

        let updated = renamed(&existing).update(&db).await?;
        assert_eq!(
            updated,
            cake::Model {
                id: 1,
                name: "Cheese Cake".to_owned(),
            }
        );

        let missing = cake::Model {
            id: 2,
            name: "New York Cheese".to_owned(),
        };

        // The mock returns no row for these three statements.
        assert_eq!(
            renamed(&missing).update(&db).await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            cake::Entity::update(renamed(&missing)).exec(&db).await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            Update::one(renamed(&missing)).exec(&db).await,
            Err(DbErr::RecordNotUpdated)
        );

        // A bulk update reports zero affected rows instead of failing.
        assert_eq!(
            Update::many(cake::Entity)
                .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned()))
                .filter(cake::Column::Id.eq(2))
                .exec(&db)
                .await,
            Ok(UpdateResult { rows_affected: 0 })
        );

        // No field is `Set` below, so no UPDATE is issued; the first three appear as SELECTs in the log.
        assert_eq!(
            cheese_cake.clone().into_active_model().save(&db).await?,
            cheese_cake.clone().into_active_model()
        );

        assert_eq!(
            cheese_cake.clone().into_active_model().update(&db).await?,
            cheese_cake
        );

        assert_eq!(
            cake::Entity::update(cheese_cake.clone().into_active_model())
                .exec(&db)
                .await?,
            cheese_cake
        );

        assert_eq!(
            cake::Entity::update_many().exec(&db).await?.rows_affected,
            0
        );

        // Expected log entries for the statements issued above.
        let update_returning = |id: i32| {
            Transaction::from_sql_and_values(
                DbBackend::Postgres,
                r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                ["Cheese Cake".into(), id.into()],
            )
        };
        let select_by_id = || {
            Transaction::from_sql_and_values(
                DbBackend::Postgres,
                r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                [1.into(), 1u64.into()],
            )
        };

        assert_eq!(
            db.into_transaction_log(),
            [
                update_returning(1),
                update_returning(2),
                update_returning(2),
                update_returning(2),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
                    ["Cheese Cake".into(), 2i32.into()],
                ),
                select_by_id(),
                select_by_id(),
                select_by_id(),
            ]
        );

        Ok(())
    }
}
359}