1use crate::{
2 error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel,
3 Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, UpdateMany, UpdateOne,
4};
5use sea_query::{FromValueTuple, Query, UpdateStatement};
6
7#[derive(Clone, Debug)]
9pub struct Updater {
10 query: UpdateStatement,
11 check_record_exists: bool,
12}
13
/// Outcome of a bulk `UPDATE`.
///
/// The `Default` value reports zero affected rows, which is also what an
/// update without any `SET` values yields.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UpdateResult {
    /// Number of rows the database reported as affected by the statement.
    pub rows_affected: u64,
}
20
21impl<A> UpdateOne<A>
22where
23 A: ActiveModelTrait,
24{
25 pub async fn exec<C>(self, db: &C) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
27 where
28 <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
29 C: ConnectionTrait,
30 {
31 Updater::new(self.query)
32 .exec_update_and_return_updated(self.model, db)
33 .await
34 }
35}
36
37impl<'a, E> UpdateMany<E>
38where
39 E: EntityTrait,
40{
41 pub async fn exec<C>(self, db: &'a C) -> Result<UpdateResult, DbErr>
43 where
44 C: ConnectionTrait,
45 {
46 Updater::new(self.query).exec(db).await
47 }
48
49 pub async fn exec_with_returning<C>(self, db: &'a C) -> Result<Vec<E::Model>, DbErr>
55 where
56 C: ConnectionTrait,
57 {
58 Updater::new(self.query)
59 .exec_update_with_returning::<E, _>(db)
60 .await
61 }
62}
63
64impl Updater {
65 pub fn new(query: UpdateStatement) -> Self {
67 Self {
68 query,
69 check_record_exists: false,
70 }
71 }
72
73 pub fn check_record_exists(mut self) -> Self {
75 self.check_record_exists = true;
76 self
77 }
78
79 pub async fn exec<C>(self, db: &C) -> Result<UpdateResult, DbErr>
81 where
82 C: ConnectionTrait,
83 {
84 if self.is_noop() {
85 return Ok(UpdateResult::default());
86 }
87 let builder = db.get_database_backend();
88 let statement = builder.build(&self.query);
89 let result = db.execute(statement).await?;
90 if self.check_record_exists && result.rows_affected() == 0 {
91 return Err(DbErr::RecordNotUpdated);
92 }
93 Ok(UpdateResult {
94 rows_affected: result.rows_affected(),
95 })
96 }
97
98 async fn exec_update_and_return_updated<A, C>(
99 mut self,
100 model: A,
101 db: &C,
102 ) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
103 where
104 A: ActiveModelTrait,
105 C: ConnectionTrait,
106 {
107 type Entity<A> = <A as ActiveModelTrait>::Entity;
108 type Model<A> = <Entity<A> as EntityTrait>::Model;
109 type Column<A> = <Entity<A> as EntityTrait>::Column;
110
111 if self.is_noop() {
112 return find_updated_model_by_id(model, db).await;
113 }
114
115 match db.support_returning() {
116 true => {
117 let db_backend = db.get_database_backend();
118 let returning = Query::returning().exprs(
119 Column::<A>::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
120 );
121 self.query.returning(returning);
122 let found: Option<Model<A>> = SelectorRaw::<SelectModel<Model<A>>>::from_statement(
123 db_backend.build(&self.query),
124 )
125 .one(db)
126 .await?;
127 match found {
129 Some(model) => Ok(model),
130 None => Err(DbErr::RecordNotUpdated),
131 }
132 }
133 false => {
134 self.check_record_exists().exec(db).await?;
136 find_updated_model_by_id(model, db).await
137 }
138 }
139 }
140
141 async fn exec_update_with_returning<E, C>(mut self, db: &C) -> Result<Vec<E::Model>, DbErr>
142 where
143 E: EntityTrait,
144 C: ConnectionTrait,
145 {
146 if self.is_noop() {
147 return Ok(vec![]);
148 }
149
150 match db.support_returning() {
151 true => {
152 let db_backend = db.get_database_backend();
153 let returning = Query::returning().exprs(
154 E::Column::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
155 );
156 self.query.returning(returning);
157 let models: Vec<E::Model> = SelectorRaw::<SelectModel<E::Model>>::from_statement(
158 db_backend.build(&self.query),
159 )
160 .all(db)
161 .await?;
162 Ok(models)
163 }
164 false => unimplemented!("Database backend doesn't support RETURNING"),
165 }
166 }
167
168 fn is_noop(&self) -> bool {
169 self.query.get_values().is_empty()
170 }
171}
172
173async fn find_updated_model_by_id<A, C>(
174 model: A,
175 db: &C,
176) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
177where
178 A: ActiveModelTrait,
179 C: ConnectionTrait,
180{
181 type Entity<A> = <A as ActiveModelTrait>::Entity;
182 type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType;
183
184 let primary_key_value = match model.get_primary_key_value() {
185 Some(val) => ValueType::<A>::from_value_tuple(val),
186 None => return Err(DbErr::UpdateGetPrimaryKey),
187 };
188 let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
189 match found {
191 Some(model) => Ok(model),
192 None => Err(DbErr::RecordNotFound(
193 "Failed to find updated item".to_owned(),
194 )),
195 }
196}
197
#[cfg(test)]
mod tests {
    use crate::{entity::prelude::*, tests_cfg::*, *};
    use pretty_assertions::assert_eq;
    use sea_query::Expr;

    /// Exercises the update paths against a Postgres mock connection. The
    /// transaction log below shows `RETURNING` being emitted for this backend.
    ///
    /// Covered cases: an update that matches a row, updates that match nothing,
    /// a bulk update, and updates that carry no changed values. The final
    /// transaction-log assertion pins the exact SQL each call issued.
    #[smol_potat::test]
    async fn update_record_not_found_1() -> Result<(), DbErr> {
        let updated_cake = cake::Model {
            id: 1,
            name: "Cheese Cake".to_owned(),
        };

        // Mock query results, matched against the row-returning calls below in
        // order (compare the transaction log at the end):
        //   1st        -> `update` of id = 1; RETURNING yields `updated_cake`
        //   2nd..=4th  -> three updates of id = 2; RETURNING yields no row
        //   5th..=7th  -> the three no-change calls, each of which only issues
        //                 a SELECT by primary key
        // The single exec result (0 rows affected) serves `Update::many`.
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                vec![updated_cake.clone()],
                vec![],
                vec![],
                vec![],
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
                vec![updated_cake.clone()],
            ])
            .append_exec_results([MockExecResult {
                last_insert_id: 0,
                rows_affected: 0,
            }])
            .into_connection();

        let model = cake::Model {
            id: 1,
            name: "New York Cheese".to_owned(),
        };

        // Row id = 1 exists: the RETURNING row is handed back as the model.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await?,
            cake::Model {
                id: 1,
                name: "Cheese Cake".to_owned(),
            }
        );

        let model = cake::Model {
            id: 2,
            name: "New York Cheese".to_owned(),
        };

        // Row id = 2 does not exist: each single-row API reports RecordNotUpdated.
        assert_eq!(
            cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            }
            .update(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            cake::Entity::update(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        assert_eq!(
            Update::one(cake::ActiveModel {
                name: Set("Cheese Cake".to_owned()),
                ..model.clone().into_active_model()
            })
            .exec(&db)
            .await,
            Err(DbErr::RecordNotUpdated)
        );

        // Bulk updates do not treat zero affected rows as an error.
        assert_eq!(
            Update::many(cake::Entity)
                .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned()))
                .filter(cake::Column::Id.eq(2))
                .exec(&db)
                .await,
            Ok(UpdateResult { rows_affected: 0 })
        );

        // Active models built straight from a model carry no changed values.
        // Per the transaction log, each of the next three calls only issues a
        // SELECT by primary key, with no UPDATE.
        assert_eq!(
            updated_cake.clone().into_active_model().save(&db).await?,
            updated_cake.clone().into_active_model()
        );

        assert_eq!(
            updated_cake.clone().into_active_model().update(&db).await?,
            updated_cake
        );

        assert_eq!(
            cake::Entity::update(updated_cake.clone().into_active_model())
                .exec(&db)
                .await?,
            updated_cake
        );

        // A bulk update with no SET values is a no-op: nothing is executed.
        assert_eq!(
            cake::Entity::update_many().exec(&db).await?.rows_affected,
            0
        );

        assert_eq!(
            db.into_transaction_log(),
            [
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 1i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2 RETURNING "id", "name""#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                // `Update::many` goes through `exec`, so it carries no RETURNING.
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
                    ["Cheese Cake".into(), 2i32.into()]
                ),
                // The three SELECTs issued by the no-change calls.
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
                Transaction::from_sql_and_values(
                    DbBackend::Postgres,
                    r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
                    [1.into(), 1u64.into()]
                ),
            ]
        );

        Ok(())
    }
}