1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use rust_decimal::Decimal;
5use sqlx::postgres::PgArguments;
6use sqlx::{Postgres, query::Query as SqlxQuery};
7use uuid::Uuid;
8
9use crate::PgExecutor;
10use crate::filter::{FilterBuilder, compile_filter_sql, schema_model as resolve_schema_model};
11use crate::query::{
12 BoxFuture, QueryFilter, QueryVariableSet, QueryVariableValue, QueryVariables, SchemaAccess,
13 StringValueType, quoted_ident, schema_error,
14};
15use crate::schema::{Field, FieldType, Model, ScalarType, Schema};
16
17pub trait UpdateSpec: Send + Sync {
19 type Output: Send + 'static;
20
21 #[doc(hidden)]
22 fn execute<'a>(
23 &'a self,
24 executor: &'a dyn PgExecutor,
25 ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>>;
26}
27
28pub trait UpdateManyModel: Sized + Send + 'static {
30 type Schema: SchemaAccess;
31 type Values: UpdateValueSet;
32 type Variables: QueryVariableSet;
33
34 fn model_name() -> &'static str;
35
36 fn filter() -> Option<QueryFilter> {
37 None
38 }
39
40 fn filter_with_variables(_variables: &QueryVariables) -> Option<QueryFilter> {
41 Self::filter()
42 }
43}
44
45pub trait UpdateValueSet: Send + 'static {
47 fn into_update_values(self) -> UpdateValues;
48}
49
50impl UpdateValueSet for UpdateValues {
51 fn into_update_values(self) -> UpdateValues {
52 self
53 }
54}
55
56impl UpdateValueSet for () {
57 fn into_update_values(self) -> UpdateValues {
58 UpdateValues::new()
59 }
60}
61
62pub trait UpdateScalar: Send {
63 fn into_update_value(self) -> UpdateValue;
64}
65
66#[derive(Clone, Debug, Default, PartialEq)]
67pub struct UpdateValues {
68 values: Vec<UpdateFieldValue>,
69 value_indices: HashMap<String, usize>,
70}
71
72impl UpdateValues {
73 pub fn new() -> Self {
74 Self {
75 values: Vec::new(),
76 value_indices: HashMap::new(),
77 }
78 }
79
80 pub fn from_values(values: Vec<(impl Into<String>, UpdateValue)>) -> Self {
81 let mut update_values = Self::new();
82
83 for (name, value) in values {
84 update_values
85 .push(name, value)
86 .expect("update field names must be unique");
87 }
88
89 update_values
90 }
91
92 pub fn push(
93 &mut self,
94 name: impl Into<String>,
95 value: UpdateValue,
96 ) -> Result<usize, sqlx::Error> {
97 let name = name.into();
98
99 if self.value_indices.contains_key(&name) {
100 return Err(schema_error(format!("duplicate update field `{name}`")));
101 }
102
103 let index = self.values.len();
104 self.values.push(UpdateFieldValue {
105 name: name.clone(),
106 value,
107 });
108 self.value_indices.insert(name, index);
109 Ok(index)
110 }
111
112 pub fn get(&self, name: &str) -> Option<&UpdateValue> {
113 self.value_indices
114 .get(name)
115 .and_then(|index| self.values.get(*index))
116 .map(|field| &field.value)
117 }
118
119 pub fn iter(&self) -> impl Iterator<Item = &UpdateFieldValue> {
120 self.values.iter()
121 }
122
123 pub fn len(&self) -> usize {
124 self.values.len()
125 }
126
127 pub fn is_empty(&self) -> bool {
128 self.values.is_empty()
129 }
130}
131
132#[derive(Clone, Debug, PartialEq)]
133pub struct UpdateFieldValue {
134 pub name: String,
135 pub value: UpdateValue,
136}
137
138#[derive(Clone, Debug, PartialEq)]
139pub enum UpdateValue {
140 Null,
141 Int(i64),
142 String(String),
143 Bool(bool),
144 Float(f64),
145 Decimal(Decimal),
146 Bytes(Vec<u8>),
147 DateTime(chrono::DateTime<chrono::Utc>),
148 Uuid(Uuid),
149}
150
151impl From<i64> for UpdateValue {
152 fn from(value: i64) -> Self {
153 Self::Int(value)
154 }
155}
156
157impl From<String> for UpdateValue {
158 fn from(value: String) -> Self {
159 Self::String(value)
160 }
161}
162
163impl From<&str> for UpdateValue {
164 fn from(value: &str) -> Self {
165 Self::String(value.to_owned())
166 }
167}
168
169impl From<bool> for UpdateValue {
170 fn from(value: bool) -> Self {
171 Self::Bool(value)
172 }
173}
174
175impl From<f64> for UpdateValue {
176 fn from(value: f64) -> Self {
177 Self::Float(value)
178 }
179}
180
181impl From<Decimal> for UpdateValue {
182 fn from(value: Decimal) -> Self {
183 Self::Decimal(value)
184 }
185}
186
187impl From<Vec<u8>> for UpdateValue {
188 fn from(value: Vec<u8>) -> Self {
189 Self::Bytes(value)
190 }
191}
192
193impl From<&[u8]> for UpdateValue {
194 fn from(value: &[u8]) -> Self {
195 Self::Bytes(value.to_vec())
196 }
197}
198
199impl From<chrono::DateTime<chrono::Utc>> for UpdateValue {
200 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
201 Self::DateTime(value)
202 }
203}
204
205impl From<Uuid> for UpdateValue {
206 fn from(value: Uuid) -> Self {
207 Self::Uuid(value)
208 }
209}
210
211impl<T> From<Option<T>> for UpdateValue
212where
213 T: Into<UpdateValue>,
214{
215 fn from(value: Option<T>) -> Self {
216 match value {
217 Some(value) => value.into(),
218 None => Self::Null,
219 }
220 }
221}
222
223impl UpdateScalar for i64 {
224 fn into_update_value(self) -> UpdateValue {
225 self.into()
226 }
227}
228
229impl UpdateScalar for &str {
230 fn into_update_value(self) -> UpdateValue {
231 self.into()
232 }
233}
234
235impl UpdateScalar for bool {
236 fn into_update_value(self) -> UpdateValue {
237 self.into()
238 }
239}
240
241impl UpdateScalar for f64 {
242 fn into_update_value(self) -> UpdateValue {
243 self.into()
244 }
245}
246
247impl UpdateScalar for Decimal {
248 fn into_update_value(self) -> UpdateValue {
249 self.into()
250 }
251}
252
253impl UpdateScalar for Vec<u8> {
254 fn into_update_value(self) -> UpdateValue {
255 self.into()
256 }
257}
258
259impl UpdateScalar for &[u8] {
260 fn into_update_value(self) -> UpdateValue {
261 self.into()
262 }
263}
264
265impl UpdateScalar for chrono::DateTime<chrono::Utc> {
266 fn into_update_value(self) -> UpdateValue {
267 self.into()
268 }
269}
270
271impl UpdateScalar for Uuid {
272 fn into_update_value(self) -> UpdateValue {
273 self.into()
274 }
275}
276
277impl<T> UpdateScalar for T
278where
279 T: StringValueType,
280{
281 fn into_update_value(self) -> UpdateValue {
282 UpdateValue::String(self.into_db_string())
283 }
284}
285
286impl<T> UpdateScalar for Option<T>
287where
288 T: UpdateScalar,
289{
290 fn into_update_value(self) -> UpdateValue {
291 match self {
292 Some(value) => value.into_update_value(),
293 None => UpdateValue::Null,
294 }
295 }
296}
297
298#[derive(Clone, Debug)]
300pub struct UpdateMany<S, T, V = ()> {
301 values: UpdateValues,
302 variables: QueryVariables,
303 _marker: PhantomData<(S, T, V)>,
304}
305
306impl<S, T> UpdateMany<S, T, ()>
307where
308 T: UpdateManyModel<Variables = ()>,
309{
310 pub fn new(values: T::Values) -> Self {
311 Self {
312 values: values.into_update_values(),
313 variables: QueryVariables::new(),
314 _marker: PhantomData,
315 }
316 }
317
318 pub fn with_values(values: UpdateValues) -> Self {
319 Self {
320 values,
321 variables: QueryVariables::new(),
322 _marker: PhantomData,
323 }
324 }
325}
326
327impl<S, T> UpdateMany<S, T, ()>
328where
329 T: UpdateManyModel,
330{
331 pub fn new_with_variables(
332 variables: T::Variables,
333 values: T::Values,
334 ) -> UpdateMany<S, T, T::Variables> {
335 UpdateMany {
336 values: values.into_update_values(),
337 variables: variables.into_query_variables(),
338 _marker: PhantomData,
339 }
340 }
341
342 pub fn with_values_and_variables(
343 values: UpdateValues,
344 variables: T::Variables,
345 ) -> UpdateMany<S, T, T::Variables> {
346 UpdateMany {
347 values,
348 variables: variables.into_query_variables(),
349 _marker: PhantomData,
350 }
351 }
352
353 pub fn with_variables(self, variables: T::Variables) -> UpdateMany<S, T, T::Variables> {
354 UpdateMany {
355 values: self.values,
356 variables: variables.into_query_variables(),
357 _marker: PhantomData,
358 }
359 }
360}
361
362impl<S, T, V> UpdateMany<S, T, V>
363where
364 S: SchemaAccess,
365 T: UpdateManyModel<Schema = S, Variables = V>,
366 V: QueryVariableSet,
367{
368 fn filter(&self) -> Option<QueryFilter> {
369 T::filter_with_variables(&self.variables)
370 }
371
372 pub fn values(&self) -> &UpdateValues {
373 &self.values
374 }
375
376 pub fn to_sql(&self) -> Result<String, sqlx::Error> {
377 let filter = self.filter();
378 let (sql, _) = build_update_many_sql(
379 S::schema(),
380 T::model_name(),
381 &self.values,
382 filter.as_ref(),
383 &self.variables,
384 )?;
385 Ok(sql)
386 }
387}
388
389impl<S, T, V> UpdateSpec for UpdateMany<S, T, V>
390where
391 S: SchemaAccess,
392 T: UpdateManyModel<Schema = S, Variables = V> + Sync,
393 V: QueryVariableSet + Sync,
394{
395 type Output = u64;
396
397 fn execute<'a>(
398 &'a self,
399 executor: &'a dyn PgExecutor,
400 ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
401 Box::pin(async move {
402 let filter = self.filter();
403 let (sql, bindings) = build_update_many_sql(
404 S::schema(),
405 T::model_name(),
406 &self.values,
407 filter.as_ref(),
408 &self.variables,
409 )?;
410 let result = executor
411 .execute(bind_update(sqlx::query(&sql), &bindings))
412 .await?;
413 Ok(result.rows_affected())
414 })
415 }
416}
417
418fn build_update_many_sql(
419 schema: &Schema,
420 model_name: &str,
421 values: &UpdateValues,
422 filter: Option<&QueryFilter>,
423 variables: &QueryVariables,
424) -> Result<(String, Vec<BoundValue>), sqlx::Error> {
425 let model = resolve_schema_model(schema, model_name, "update")?;
426
427 validate_update_values(model, values)?;
428
429 let ordered_values = ordered_update_values(model, values);
430 let mut builder = UpdateSqlBuilder {
431 schema,
432 variables,
433 bindings: Vec::new(),
434 next_alias: 1,
435 };
436
437 let assignments = ordered_values
438 .iter()
439 .map(|(field, value)| {
440 let scalar = match field.ty() {
441 FieldType::Scalar(scalar) => scalar.scalar(),
442 FieldType::Relation { .. } => {
443 return Err(schema_error(format!(
444 "field `{}.{}` is not scalar and cannot appear in `data`",
445 model.name(),
446 field.name()
447 )));
448 }
449 };
450 let placeholder =
451 builder.push_update_binding((*value).clone(), scalar, field.has_db_uuid())?;
452 Ok(format!(
453 r#"{} = {}"#,
454 quoted_ident(field.name()),
455 placeholder
456 ))
457 })
458 .collect::<Result<Vec<_>, sqlx::Error>>()?;
459
460 let where_clause = filter
461 .map(|filter| builder.filter_sql(model, filter, "t0"))
462 .transpose()?;
463
464 let sql = format!(
465 r#"UPDATE {} AS "t0" SET {}{}"#,
466 quoted_ident(model.name()),
467 assignments.join(", "),
468 where_clause
469 .map(|where_clause| format!(" WHERE {where_clause}"))
470 .unwrap_or_default(),
471 );
472
473 Ok((sql, builder.bindings))
474}
475
476fn validate_update_values(model: &Model, values: &UpdateValues) -> Result<(), sqlx::Error> {
477 if values.is_empty() {
478 return Err(schema_error(format!(
479 "update on model `{}` must write at least one scalar field",
480 model.name()
481 )));
482 }
483
484 for provided in values.iter() {
485 let field = model.field_named(&provided.name).ok_or_else(|| {
486 schema_error(format!(
487 "unknown field `{}` in update for model `{}`",
488 provided.name,
489 model.name()
490 ))
491 })?;
492
493 if field.kind().is_relation() {
494 return Err(schema_error(format!(
495 "relation field `{}` cannot be written in update for model `{}`",
496 field.name(),
497 model.name()
498 )));
499 }
500
501 if !update_value_matches_field(&provided.value, field) {
502 return Err(schema_error(format!(
503 "update value for field `{}` is incompatible with schema type `{}` on model `{}`",
504 field.name(),
505 field.ty().name(),
506 model.name()
507 )));
508 }
509 }
510
511 Ok(())
512}
513
514fn ordered_update_values<'a>(
515 model: &'a Model,
516 values: &'a UpdateValues,
517) -> Vec<(&'a Field, &'a UpdateValue)> {
518 let mut ordered = Vec::new();
519
520 for field in model.fields() {
521 if field.kind().is_relation() {
522 continue;
523 }
524
525 if let Some(value) = values.get(field.name()) {
526 ordered.push((field, value));
527 }
528 }
529
530 ordered
531}
532
533struct UpdateSqlBuilder<'a> {
534 schema: &'a Schema,
535 variables: &'a QueryVariables,
536 bindings: Vec<BoundValue>,
537 next_alias: usize,
538}
539
540impl<'a> UpdateSqlBuilder<'a> {
541 fn filter_sql(
542 &mut self,
543 model: &'a Model,
544 filter: &QueryFilter,
545 table_alias: &str,
546 ) -> Result<String, sqlx::Error> {
547 compile_filter_sql(self, model, filter, table_alias)
548 }
549
550 fn push_update_binding(
551 &mut self,
552 value: UpdateValue,
553 scalar: ScalarType,
554 is_db_uuid: bool,
555 ) -> Result<String, sqlx::Error> {
556 let binding = match (value, scalar, is_db_uuid) {
557 (UpdateValue::Null, ScalarType::String, false) => BoundValue::NullString,
558 (UpdateValue::Null, ScalarType::Boolean, _) => BoundValue::NullBool,
559 (UpdateValue::Null, ScalarType::Float, _) => BoundValue::NullFloat,
560 (UpdateValue::Null, ScalarType::Decimal, _) => BoundValue::NullDecimal,
561 (UpdateValue::Null, ScalarType::Bytes, _) => BoundValue::NullBytes,
562 (UpdateValue::Null, ScalarType::DateTime, _) => BoundValue::NullDateTime,
563 (UpdateValue::Null, ScalarType::String, true) => BoundValue::NullUuid,
564 (value, _, _) => value.into(),
565 };
566
567 self.bindings.push(binding);
568 Ok(format!("${}", self.bindings.len()))
569 }
570
571 fn push_query_binding(
572 &mut self,
573 value: QueryVariableValue,
574 _scalar: ScalarType,
575 ) -> Result<String, sqlx::Error> {
576 self.bindings.push(value.into());
577 Ok(format!("${}", self.bindings.len()))
578 }
579}
580
581impl<'a> FilterBuilder<'a> for UpdateSqlBuilder<'a> {
582 fn schema(&self) -> &'a Schema {
583 self.schema
584 }
585
586 fn variables(&self) -> &'a QueryVariables {
587 self.variables
588 }
589
590 fn push_filter_binding(
591 &mut self,
592 value: QueryVariableValue,
593 scalar: ScalarType,
594 ) -> Result<String, sqlx::Error> {
595 self.push_query_binding(value, scalar)
596 }
597
598 fn next_filter_alias(&mut self) -> String {
599 let alias = format!("t{}", self.next_alias);
600 self.next_alias += 1;
601 alias
602 }
603
604 fn operation_name(&self) -> &'static str {
605 "update"
606 }
607}
608
609#[derive(Clone, Debug, PartialEq)]
610enum BoundValue {
611 Null,
612 NullString,
613 NullBool,
614 NullFloat,
615 NullDecimal,
616 NullBytes,
617 NullDateTime,
618 NullUuid,
619 Int(i64),
620 String(String),
621 Bool(bool),
622 Float(f64),
623 Decimal(Decimal),
624 Bytes(Vec<u8>),
625 DateTime(chrono::DateTime<chrono::Utc>),
626 Uuid(Uuid),
627 List(Vec<QueryVariableValue>),
628}
629
630impl From<UpdateValue> for BoundValue {
631 fn from(value: UpdateValue) -> Self {
632 match value {
633 UpdateValue::Null => Self::Null,
634 UpdateValue::Int(value) => Self::Int(value),
635 UpdateValue::String(value) => Self::String(value),
636 UpdateValue::Bool(value) => Self::Bool(value),
637 UpdateValue::Float(value) => Self::Float(value),
638 UpdateValue::Decimal(value) => Self::Decimal(value),
639 UpdateValue::Bytes(value) => Self::Bytes(value),
640 UpdateValue::DateTime(value) => Self::DateTime(value),
641 UpdateValue::Uuid(value) => Self::Uuid(value),
642 }
643 }
644}
645
646impl From<QueryVariableValue> for BoundValue {
647 fn from(value: QueryVariableValue) -> Self {
648 match value {
649 QueryVariableValue::Null => Self::Null,
650 QueryVariableValue::Int(value) => Self::Int(value),
651 QueryVariableValue::String(value) => Self::String(value),
652 QueryVariableValue::Bool(value) => Self::Bool(value),
653 QueryVariableValue::Float(value) => Self::Float(value),
654 QueryVariableValue::Decimal(value) => Self::Decimal(value),
655 QueryVariableValue::Bytes(value) => Self::Bytes(value),
656 QueryVariableValue::DateTime(value) => Self::DateTime(value),
657 QueryVariableValue::Uuid(value) => Self::Uuid(value),
658 QueryVariableValue::List(values) => Self::List(values),
659 }
660 }
661}
662
663fn update_value_matches_field(value: &UpdateValue, field: &Field) -> bool {
664 let FieldType::Scalar(scalar) = field.ty() else {
665 return false;
666 };
667
668 match value {
669 UpdateValue::Null => scalar.optional(),
670 UpdateValue::Int(_) => {
671 matches!(scalar.scalar(), ScalarType::Int | ScalarType::BigInt)
672 }
673 UpdateValue::String(_) => scalar.scalar() == ScalarType::String && !field.has_db_uuid(),
674 UpdateValue::Bool(_) => scalar.scalar() == ScalarType::Boolean,
675 UpdateValue::Float(_) => scalar.scalar() == ScalarType::Float,
676 UpdateValue::Decimal(_) => scalar.scalar() == ScalarType::Decimal,
677 UpdateValue::Bytes(_) => scalar.scalar() == ScalarType::Bytes,
678 UpdateValue::DateTime(_) => scalar.scalar() == ScalarType::DateTime,
679 UpdateValue::Uuid(_) => scalar.scalar() == ScalarType::String && field.has_db_uuid(),
680 }
681}
682
683fn bind_update<'q>(
684 mut query: SqlxQuery<'q, Postgres, PgArguments>,
685 bindings: &'q [BoundValue],
686) -> SqlxQuery<'q, Postgres, PgArguments> {
687 for binding in bindings {
688 query = match binding {
689 BoundValue::Null => query.bind(Option::<i64>::None),
690 BoundValue::NullString => query.bind(Option::<String>::None),
691 BoundValue::NullBool => query.bind(Option::<bool>::None),
692 BoundValue::NullFloat => query.bind(Option::<f64>::None),
693 BoundValue::NullDecimal => query.bind(Option::<Decimal>::None),
694 BoundValue::NullBytes => query.bind(Option::<Vec<u8>>::None),
695 BoundValue::NullDateTime => query.bind(Option::<chrono::DateTime<chrono::Utc>>::None),
696 BoundValue::NullUuid => query.bind(Option::<Uuid>::None),
697 BoundValue::Int(value) => query.bind(*value),
698 BoundValue::String(value) => query.bind(value),
699 BoundValue::Bool(value) => query.bind(*value),
700 BoundValue::Float(value) => query.bind(*value),
701 BoundValue::Decimal(value) => query.bind(*value),
702 BoundValue::Bytes(value) => query.bind(value),
703 BoundValue::DateTime(value) => query.bind(*value),
704 BoundValue::Uuid(value) => query.bind(*value),
705 BoundValue::List(values) => {
706 let first = values
707 .first()
708 .expect("list-valued query variables must not be empty when bound");
709
710 match first {
711 QueryVariableValue::Null => {
712 unreachable!("list-valued query variables must not contain null items")
713 }
714 QueryVariableValue::Int(_) => query.bind(
715 values
716 .iter()
717 .map(|value| match value {
718 QueryVariableValue::Int(value) => *value,
719 _ => unreachable!("list-valued query variables must be homogenous"),
720 })
721 .collect::<Vec<_>>(),
722 ),
723 QueryVariableValue::String(_) => query.bind(
724 values
725 .iter()
726 .map(|value| match value {
727 QueryVariableValue::String(value) => value.clone(),
728 _ => unreachable!("list-valued query variables must be homogenous"),
729 })
730 .collect::<Vec<_>>(),
731 ),
732 QueryVariableValue::Bool(_) => query.bind(
733 values
734 .iter()
735 .map(|value| match value {
736 QueryVariableValue::Bool(value) => *value,
737 _ => unreachable!("list-valued query variables must be homogenous"),
738 })
739 .collect::<Vec<_>>(),
740 ),
741 QueryVariableValue::Float(_) => query.bind(
742 values
743 .iter()
744 .map(|value| match value {
745 QueryVariableValue::Float(value) => *value,
746 _ => unreachable!("list-valued query variables must be homogenous"),
747 })
748 .collect::<Vec<_>>(),
749 ),
750 QueryVariableValue::Decimal(_) => query.bind(
751 values
752 .iter()
753 .map(|value| match value {
754 QueryVariableValue::Decimal(value) => *value,
755 _ => unreachable!("list-valued query variables must be homogenous"),
756 })
757 .collect::<Vec<_>>(),
758 ),
759 QueryVariableValue::Bytes(_) => query.bind(
760 values
761 .iter()
762 .map(|value| match value {
763 QueryVariableValue::Bytes(value) => value.clone(),
764 _ => unreachable!("list-valued query variables must be homogenous"),
765 })
766 .collect::<Vec<_>>(),
767 ),
768 QueryVariableValue::DateTime(_) => query.bind(
769 values
770 .iter()
771 .map(|value| match value {
772 QueryVariableValue::DateTime(value) => *value,
773 _ => unreachable!("list-valued query variables must be homogenous"),
774 })
775 .collect::<Vec<_>>(),
776 ),
777 QueryVariableValue::Uuid(_) => query.bind(
778 values
779 .iter()
780 .map(|value| match value {
781 QueryVariableValue::Uuid(value) => *value,
782 _ => unreachable!("list-valued query variables must be homogenous"),
783 })
784 .collect::<Vec<_>>(),
785 ),
786 QueryVariableValue::List(_) => {
787 unreachable!("list-valued query variables must not contain nested lists")
788 }
789 }
790 }
791 };
792 }
793
794 query
795}