1use super::ReturningSelector;
2use crate::{
3 ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, Iterable,
4 PrimaryKeyTrait, SelectModel, UpdateMany, UpdateOne, ValidatedUpdateOne, error::*,
5};
6use sea_query::{FromValueTuple, Query, UpdateStatement};
7
8#[derive(Clone, Debug)]
10pub struct Updater {
11 query: UpdateStatement,
12 check_record_exists: bool,
13}
14
/// Outcome of executing an `UPDATE` statement.
///
/// The default value reports zero affected rows; it is also what a no-op
/// update (one with nothing to set) yields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateResult {
    /// Number of rows the connection reported as affected by the statement.
    pub rows_affected: u64,
}
21
22impl<A> ValidatedUpdateOne<A>
23where
24 A: ActiveModelTrait,
25{
26 pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
28 where
29 <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
30 C: ConnectionTrait,
31 {
32 Updater::new(self.query)
33 .exec_update_and_return_updated(self.model, db)
34 .await
35 }
36}
37
38impl<A> UpdateOne<A>
39where
40 A: ActiveModelTrait,
41{
42 pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
44 where
45 <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
46 C: ConnectionTrait,
47 {
48 self.0?.exec(db).await
49 }
50}
51
52impl<'a, E> UpdateMany<E>
53where
54 E: EntityTrait,
55{
56 pub async fn exec<C>(self, db: &'a C) -> Result<UpdateResult, DbErr>
58 where
59 C: ConnectionTrait,
60 {
61 Updater::new(self.query).exec(db).await
62 }
63
64 pub async fn exec_with_returning<C>(self, db: &'a C) -> Result<Vec<E::Model>, DbErr>
66 where
67 C: ConnectionTrait,
68 {
69 Updater::new(self.query)
70 .exec_update_with_returning::<E, _>(db)
71 .await
72 }
73}
74
75impl Updater {
76 fn new(query: UpdateStatement) -> Self {
78 Self {
79 query,
80 check_record_exists: false,
81 }
82 }
83
84 pub async fn exec<C>(self, db: &C) -> Result<UpdateResult, DbErr>
86 where
87 C: ConnectionTrait,
88 {
89 if self.is_noop() {
90 return Ok(UpdateResult::default());
91 }
92 let result = db.execute(&self.query).await?;
93 if self.check_record_exists && result.rows_affected() == 0 {
94 return Err(DbErr::RecordNotUpdated);
95 }
96 Ok(UpdateResult {
97 rows_affected: result.rows_affected(),
98 })
99 }
100
101 async fn exec_update_and_return_updated<A, C>(
102 mut self,
103 model: A,
104 db: &C,
105 ) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
106 where
107 A: ActiveModelTrait,
108 C: ConnectionTrait,
109 {
110 type Entity<A> = <A as ActiveModelTrait>::Entity;
111 type Model<A> = <Entity<A> as EntityTrait>::Model;
112 type Column<A> = <Entity<A> as EntityTrait>::Column;
113
114 if self.is_noop() {
115 return find_updated_model_by_id(model, db).await;
116 }
117
118 match db.support_returning() {
119 true => {
120 let db_backend = db.get_database_backend();
121 let returning = Query::returning().exprs(
122 Column::<A>::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
123 );
124 self.query.returning(returning);
125 let found: Option<Model<A>> =
126 ReturningSelector::<SelectModel<Model<A>>, _>::from_query(self.query)
127 .one(db)
128 .await?;
129 match found {
131 Some(model) => Ok(model),
132 None => Err(DbErr::RecordNotUpdated),
133 }
134 }
135 false => {
136 self.check_record_exists = true;
138 self.exec(db).await?;
139 find_updated_model_by_id(model, db).await
140 }
141 }
142 }
143
144 async fn exec_update_with_returning<E, C>(mut self, db: &C) -> Result<Vec<E::Model>, DbErr>
145 where
146 E: EntityTrait,
147 C: ConnectionTrait,
148 {
149 if self.is_noop() {
150 return Ok(vec![]);
151 }
152
153 let db_backend = db.get_database_backend();
154 match db.support_returning() {
155 true => {
156 let returning = Query::returning().exprs(
157 E::Column::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
158 );
159 self.query.returning(returning);
160 let models: Vec<E::Model> =
161 ReturningSelector::<SelectModel<E::Model>, _>::from_query(self.query)
162 .all(db)
163 .await?;
164 Ok(models)
165 }
166 false => Err(DbErr::BackendNotSupported {
167 db: db_backend.as_str(),
168 ctx: "UPDATE RETURNING",
169 }),
170 }
171 }
172
173 fn is_noop(&self) -> bool {
174 self.query.get_values().is_empty()
175 }
176}
177
178async fn find_updated_model_by_id<A, C>(
179 model: A,
180 db: &C,
181) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
182where
183 A: ActiveModelTrait,
184 C: ConnectionTrait,
185{
186 type Entity<A> = <A as ActiveModelTrait>::Entity;
187 type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType;
188
189 let primary_key_value = match model.get_primary_key_value() {
190 Some(val) => ValueType::<A>::from_value_tuple(val),
191 None => return Err(DbErr::UpdateGetPrimaryKey),
192 };
193 let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
194 match found {
196 Some(model) => Ok(model),
197 None => Err(DbErr::RecordNotFound(
198 "Failed to find updated item".to_owned(),
199 )),
200 }
201}
202
#[cfg(test)]
mod tests {
    use crate::{
        ColumnTrait, DbBackend, DbErr, EntityTrait, IntoActiveModel, MockDatabase, MockExecResult,
        QueryFilter, Set, Transaction, Update, UpdateResult, tests_cfg::cake,
    };
    use pretty_assertions::assert_eq;
    use sea_query::Expr;

    // Drives single-row and bulk updates against a Postgres mock (which supports
    // RETURNING). It covers a row that gets updated, rows that are missing and
    // no-op updates, then checks the exact statements that were issued.
    #[smol_potat::test]
    async fn update_record_not_found_1() -> Result<(), DbErr> {
        use crate::ActiveModelTrait;

        let updated_cake = cake::Model {
            id: 1,
            name: "Cheese Cake".to_owned(),
        };

        // The mock hands out queued results in order. Query results are consumed
        // by statements that return rows (UPDATE ... RETURNING, SELECT); exec
        // results are consumed by statements executed without RETURNING.
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                // UPDATE ... RETURNING for id = 1: the row is found.
                vec![updated_cake.clone()],
                // Three UPDATE ... RETURNING for id = 2: nothing comes back.
                vec![],
                vec![],
                vec![],
                // Three SELECTs issued when no-op updates re-read the row.
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
            ])
            // The bulk update (executed without RETURNING) affects zero rows.
            .append_exec_results([MockExecResult {
                last_insert_id: 0,
                rows_affected: 0,
            }])
            .into_connection();

        let model = cake::Model {
            id: 1,
            name: "New York Cheese".to_owned(),
        };

        // Changed name on an existing row: the RETURNING row is the result.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await?,
            cake::Model {
                id: 1,
                name: "Cheese Cake".to_owned(),
            }
        );

        let model = cake::Model {
            id: 2,
            name: "New York Cheese".to_owned(),
        };

        // The next three cases target id = 2, whose RETURNING result is empty,
        // through three different entry points; each must surface RecordNotUpdated.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            cake::Entity::update(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            Update::one(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        // A bulk update matching no rows is not an error; it reports zero rows.
        assert_eq!(
            Update::many(cake::Entity)
                .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned()))
                .filter(cake::Column::Id.eq(2))
                .exec(&db)
                .await,
            Ok(UpdateResult { rows_affected: 0 })
        );

        // Active models built straight from a Model carry nothing to SET. These
        // updates are expected to take the no-op path: no UPDATE is sent, and
        // the row is re-read with a SELECT (see the transaction log below).
        assert_eq!(
            updated_cake.clone().into_active_model().save(&db).await?,
            updated_cake.clone().into_active_model()
        );

        assert_eq!(
            updated_cake.clone().into_active_model().update(&db).await?,
            updated_cake
        );

        assert_eq!(
            cake::Entity::update(updated_cake.clone().into_active_model())
                .exec(&db)
                .await?,
            updated_cake
        );

        // A bulk update with no SET values sends nothing and reports zero rows.
        assert_eq!(
            cake::Entity::update_many().exec(&db).await?.rows_affected,
            0
        );

        // Exactly eight statements: four UPDATE ... RETURNING, one plain bulk
        // UPDATE and three SELECTs from the no-op updates. The final no-op
        // update_many issued nothing.
        assert_eq!(
            db.into_transaction_log(),
            [
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 1i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
            ]
        );

        Ok(())
    }

    // An active model without a primary key must be rejected with
    // PrimaryKeyNotSet through both single-row entry points. The mock has no
    // queued results, so the error is expected before any statement runs.
    #[smol_potat::test]
    async fn update_error() {
        use crate::{DbBackend, DbErr, MockDatabase};

        let db = MockDatabase::new(DbBackend::MySql).into_connection();

        assert!(matches!(
            Update::one(cake::ActiveModel {
                ..Default::default()
            })
            .exec(&db)
            .await,
            Err(DbErr::PrimaryKeyNotSet { .. })
        ));

        assert!(matches!(
            cake::Entity::update(cake::ActiveModel::default())
                .exec(&db)
                .await,
            Err(DbErr::PrimaryKeyNotSet { .. })
        ));
    }
}