Skip to main content

vitrail_pg_core/
update.rs

1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use rust_decimal::Decimal;
5use sqlx::postgres::PgArguments;
6use sqlx::{Postgres, query::Query as SqlxQuery};
7use uuid::Uuid;
8
9use crate::PgExecutor;
10use crate::filter::{FilterBuilder, compile_filter_sql, schema_model as resolve_schema_model};
11use crate::query::{
12    BoxFuture, QueryFilter, QueryVariableSet, QueryVariableValue, QueryVariables, SchemaAccess,
13    StringValueType, quoted_ident, schema_error,
14};
15use crate::schema::{Field, FieldType, Model, ScalarType, Schema};
16
17/// Runtime contract implemented by executable update values.
18pub trait UpdateSpec: Send + Sync {
19    type Output: Send + 'static;
20
21    #[doc(hidden)]
22    fn execute<'a>(
23        &'a self,
24        executor: &'a dyn PgExecutor,
25    ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>>;
26}
27
28/// Runtime contract implemented by bulk update models.
29pub trait UpdateManyModel: Sized + Send + 'static {
30    type Schema: SchemaAccess;
31    type Values: UpdateValueSet;
32    type Variables: QueryVariableSet;
33
34    fn model_name() -> &'static str;
35
36    fn filter() -> Option<QueryFilter> {
37        None
38    }
39
40    fn filter_with_variables(_variables: &QueryVariables) -> Option<QueryFilter> {
41        Self::filter()
42    }
43}
44
45/// Converts a user-provided input into executable update values.
46pub trait UpdateValueSet: Send + 'static {
47    fn into_update_values(self) -> UpdateValues;
48}
49
50impl UpdateValueSet for UpdateValues {
51    fn into_update_values(self) -> UpdateValues {
52        self
53    }
54}
55
56impl UpdateValueSet for () {
57    fn into_update_values(self) -> UpdateValues {
58        UpdateValues::new()
59    }
60}
61
62pub trait UpdateScalar: Send {
63    fn into_update_value(self) -> UpdateValue;
64}
65
66#[derive(Clone, Debug, Default, PartialEq)]
67pub struct UpdateValues {
68    values: Vec<UpdateFieldValue>,
69    value_indices: HashMap<String, usize>,
70}
71
72impl UpdateValues {
73    pub fn new() -> Self {
74        Self {
75            values: Vec::new(),
76            value_indices: HashMap::new(),
77        }
78    }
79
80    pub fn from_values(values: Vec<(impl Into<String>, UpdateValue)>) -> Self {
81        let mut update_values = Self::new();
82
83        for (name, value) in values {
84            update_values
85                .push(name, value)
86                .expect("update field names must be unique");
87        }
88
89        update_values
90    }
91
92    pub fn push(
93        &mut self,
94        name: impl Into<String>,
95        value: UpdateValue,
96    ) -> Result<usize, sqlx::Error> {
97        let name = name.into();
98
99        if self.value_indices.contains_key(&name) {
100            return Err(schema_error(format!("duplicate update field `{name}`")));
101        }
102
103        let index = self.values.len();
104        self.values.push(UpdateFieldValue {
105            name: name.clone(),
106            value,
107        });
108        self.value_indices.insert(name, index);
109        Ok(index)
110    }
111
112    pub fn get(&self, name: &str) -> Option<&UpdateValue> {
113        self.value_indices
114            .get(name)
115            .and_then(|index| self.values.get(*index))
116            .map(|field| &field.value)
117    }
118
119    pub fn iter(&self) -> impl Iterator<Item = &UpdateFieldValue> {
120        self.values.iter()
121    }
122
123    pub fn len(&self) -> usize {
124        self.values.len()
125    }
126
127    pub fn is_empty(&self) -> bool {
128        self.values.is_empty()
129    }
130}
131
132#[derive(Clone, Debug, PartialEq)]
133pub struct UpdateFieldValue {
134    pub name: String,
135    pub value: UpdateValue,
136}
137
138#[derive(Clone, Debug, PartialEq)]
139pub enum UpdateValue {
140    Null,
141    Int(i64),
142    String(String),
143    Bool(bool),
144    Float(f64),
145    Decimal(Decimal),
146    Bytes(Vec<u8>),
147    DateTime(chrono::DateTime<chrono::Utc>),
148    Uuid(Uuid),
149}
150
151impl From<i64> for UpdateValue {
152    fn from(value: i64) -> Self {
153        Self::Int(value)
154    }
155}
156
157impl From<String> for UpdateValue {
158    fn from(value: String) -> Self {
159        Self::String(value)
160    }
161}
162
163impl From<&str> for UpdateValue {
164    fn from(value: &str) -> Self {
165        Self::String(value.to_owned())
166    }
167}
168
169impl From<bool> for UpdateValue {
170    fn from(value: bool) -> Self {
171        Self::Bool(value)
172    }
173}
174
175impl From<f64> for UpdateValue {
176    fn from(value: f64) -> Self {
177        Self::Float(value)
178    }
179}
180
181impl From<Decimal> for UpdateValue {
182    fn from(value: Decimal) -> Self {
183        Self::Decimal(value)
184    }
185}
186
187impl From<Vec<u8>> for UpdateValue {
188    fn from(value: Vec<u8>) -> Self {
189        Self::Bytes(value)
190    }
191}
192
193impl From<&[u8]> for UpdateValue {
194    fn from(value: &[u8]) -> Self {
195        Self::Bytes(value.to_vec())
196    }
197}
198
199impl From<chrono::DateTime<chrono::Utc>> for UpdateValue {
200    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
201        Self::DateTime(value)
202    }
203}
204
205impl From<Uuid> for UpdateValue {
206    fn from(value: Uuid) -> Self {
207        Self::Uuid(value)
208    }
209}
210
211impl<T> From<Option<T>> for UpdateValue
212where
213    T: Into<UpdateValue>,
214{
215    fn from(value: Option<T>) -> Self {
216        match value {
217            Some(value) => value.into(),
218            None => Self::Null,
219        }
220    }
221}
222
223impl UpdateScalar for i64 {
224    fn into_update_value(self) -> UpdateValue {
225        self.into()
226    }
227}
228
229impl UpdateScalar for &str {
230    fn into_update_value(self) -> UpdateValue {
231        self.into()
232    }
233}
234
235impl UpdateScalar for bool {
236    fn into_update_value(self) -> UpdateValue {
237        self.into()
238    }
239}
240
241impl UpdateScalar for f64 {
242    fn into_update_value(self) -> UpdateValue {
243        self.into()
244    }
245}
246
247impl UpdateScalar for Decimal {
248    fn into_update_value(self) -> UpdateValue {
249        self.into()
250    }
251}
252
253impl UpdateScalar for Vec<u8> {
254    fn into_update_value(self) -> UpdateValue {
255        self.into()
256    }
257}
258
259impl UpdateScalar for &[u8] {
260    fn into_update_value(self) -> UpdateValue {
261        self.into()
262    }
263}
264
265impl UpdateScalar for chrono::DateTime<chrono::Utc> {
266    fn into_update_value(self) -> UpdateValue {
267        self.into()
268    }
269}
270
271impl UpdateScalar for Uuid {
272    fn into_update_value(self) -> UpdateValue {
273        self.into()
274    }
275}
276
277impl<T> UpdateScalar for T
278where
279    T: StringValueType,
280{
281    fn into_update_value(self) -> UpdateValue {
282        UpdateValue::String(self.into_db_string())
283    }
284}
285
286impl<T> UpdateScalar for Option<T>
287where
288    T: UpdateScalar,
289{
290    fn into_update_value(self) -> UpdateValue {
291        match self {
292            Some(value) => value.into_update_value(),
293            None => UpdateValue::Null,
294        }
295    }
296}
297
298/// Executable scalar bulk update returning the number of affected rows.
299#[derive(Clone, Debug)]
300pub struct UpdateMany<S, T, V = ()> {
301    values: UpdateValues,
302    variables: QueryVariables,
303    _marker: PhantomData<(S, T, V)>,
304}
305
306impl<S, T> UpdateMany<S, T, ()>
307where
308    T: UpdateManyModel<Variables = ()>,
309{
310    pub fn new(values: T::Values) -> Self {
311        Self {
312            values: values.into_update_values(),
313            variables: QueryVariables::new(),
314            _marker: PhantomData,
315        }
316    }
317
318    pub fn with_values(values: UpdateValues) -> Self {
319        Self {
320            values,
321            variables: QueryVariables::new(),
322            _marker: PhantomData,
323        }
324    }
325}
326
327impl<S, T> UpdateMany<S, T, ()>
328where
329    T: UpdateManyModel,
330{
331    pub fn new_with_variables(
332        variables: T::Variables,
333        values: T::Values,
334    ) -> UpdateMany<S, T, T::Variables> {
335        UpdateMany {
336            values: values.into_update_values(),
337            variables: variables.into_query_variables(),
338            _marker: PhantomData,
339        }
340    }
341
342    pub fn with_values_and_variables(
343        values: UpdateValues,
344        variables: T::Variables,
345    ) -> UpdateMany<S, T, T::Variables> {
346        UpdateMany {
347            values,
348            variables: variables.into_query_variables(),
349            _marker: PhantomData,
350        }
351    }
352
353    pub fn with_variables(self, variables: T::Variables) -> UpdateMany<S, T, T::Variables> {
354        UpdateMany {
355            values: self.values,
356            variables: variables.into_query_variables(),
357            _marker: PhantomData,
358        }
359    }
360}
361
362impl<S, T, V> UpdateMany<S, T, V>
363where
364    S: SchemaAccess,
365    T: UpdateManyModel<Schema = S, Variables = V>,
366    V: QueryVariableSet,
367{
368    fn filter(&self) -> Option<QueryFilter> {
369        T::filter_with_variables(&self.variables)
370    }
371
372    pub fn values(&self) -> &UpdateValues {
373        &self.values
374    }
375
376    pub fn to_sql(&self) -> Result<String, sqlx::Error> {
377        let filter = self.filter();
378        let (sql, _) = build_update_many_sql(
379            S::schema(),
380            T::model_name(),
381            &self.values,
382            filter.as_ref(),
383            &self.variables,
384        )?;
385        Ok(sql)
386    }
387}
388
389impl<S, T, V> UpdateSpec for UpdateMany<S, T, V>
390where
391    S: SchemaAccess,
392    T: UpdateManyModel<Schema = S, Variables = V> + Sync,
393    V: QueryVariableSet + Sync,
394{
395    type Output = u64;
396
397    fn execute<'a>(
398        &'a self,
399        executor: &'a dyn PgExecutor,
400    ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
401        Box::pin(async move {
402            let filter = self.filter();
403            let (sql, bindings) = build_update_many_sql(
404                S::schema(),
405                T::model_name(),
406                &self.values,
407                filter.as_ref(),
408                &self.variables,
409            )?;
410            let result = executor
411                .execute(bind_update(sqlx::query(&sql), &bindings))
412                .await?;
413            Ok(result.rows_affected())
414        })
415    }
416}
417
418fn build_update_many_sql(
419    schema: &Schema,
420    model_name: &str,
421    values: &UpdateValues,
422    filter: Option<&QueryFilter>,
423    variables: &QueryVariables,
424) -> Result<(String, Vec<BoundValue>), sqlx::Error> {
425    let model = resolve_schema_model(schema, model_name, "update")?;
426
427    validate_update_values(model, values)?;
428
429    let ordered_values = ordered_update_values(model, values);
430    let mut builder = UpdateSqlBuilder {
431        schema,
432        variables,
433        bindings: Vec::new(),
434        next_alias: 1,
435    };
436
437    let assignments = ordered_values
438        .iter()
439        .map(|(field, value)| {
440            let scalar = match field.ty() {
441                FieldType::Scalar(scalar) => scalar.scalar(),
442                FieldType::Relation { .. } => {
443                    return Err(schema_error(format!(
444                        "field `{}.{}` is not scalar and cannot appear in `data`",
445                        model.name(),
446                        field.name()
447                    )));
448                }
449            };
450            let placeholder =
451                builder.push_update_binding((*value).clone(), scalar, field.has_db_uuid())?;
452            Ok(format!(
453                r#"{} = {}"#,
454                quoted_ident(field.name()),
455                placeholder
456            ))
457        })
458        .collect::<Result<Vec<_>, sqlx::Error>>()?;
459
460    let where_clause = filter
461        .map(|filter| builder.filter_sql(model, filter, "t0"))
462        .transpose()?;
463
464    let sql = format!(
465        r#"UPDATE {} AS "t0" SET {}{}"#,
466        quoted_ident(model.name()),
467        assignments.join(", "),
468        where_clause
469            .map(|where_clause| format!(" WHERE {where_clause}"))
470            .unwrap_or_default(),
471    );
472
473    Ok((sql, builder.bindings))
474}
475
476fn validate_update_values(model: &Model, values: &UpdateValues) -> Result<(), sqlx::Error> {
477    if values.is_empty() {
478        return Err(schema_error(format!(
479            "update on model `{}` must write at least one scalar field",
480            model.name()
481        )));
482    }
483
484    for provided in values.iter() {
485        let field = model.field_named(&provided.name).ok_or_else(|| {
486            schema_error(format!(
487                "unknown field `{}` in update for model `{}`",
488                provided.name,
489                model.name()
490            ))
491        })?;
492
493        if field.kind().is_relation() {
494            return Err(schema_error(format!(
495                "relation field `{}` cannot be written in update for model `{}`",
496                field.name(),
497                model.name()
498            )));
499        }
500
501        if !update_value_matches_field(&provided.value, field) {
502            return Err(schema_error(format!(
503                "update value for field `{}` is incompatible with schema type `{}` on model `{}`",
504                field.name(),
505                field.ty().name(),
506                model.name()
507            )));
508        }
509    }
510
511    Ok(())
512}
513
514fn ordered_update_values<'a>(
515    model: &'a Model,
516    values: &'a UpdateValues,
517) -> Vec<(&'a Field, &'a UpdateValue)> {
518    let mut ordered = Vec::new();
519
520    for field in model.fields() {
521        if field.kind().is_relation() {
522            continue;
523        }
524
525        if let Some(value) = values.get(field.name()) {
526            ordered.push((field, value));
527        }
528    }
529
530    ordered
531}
532
533struct UpdateSqlBuilder<'a> {
534    schema: &'a Schema,
535    variables: &'a QueryVariables,
536    bindings: Vec<BoundValue>,
537    next_alias: usize,
538}
539
540impl<'a> UpdateSqlBuilder<'a> {
541    fn filter_sql(
542        &mut self,
543        model: &'a Model,
544        filter: &QueryFilter,
545        table_alias: &str,
546    ) -> Result<String, sqlx::Error> {
547        compile_filter_sql(self, model, filter, table_alias)
548    }
549
550    fn push_update_binding(
551        &mut self,
552        value: UpdateValue,
553        scalar: ScalarType,
554        is_db_uuid: bool,
555    ) -> Result<String, sqlx::Error> {
556        let binding = match (value, scalar, is_db_uuid) {
557            (UpdateValue::Null, ScalarType::String, false) => BoundValue::NullString,
558            (UpdateValue::Null, ScalarType::Boolean, _) => BoundValue::NullBool,
559            (UpdateValue::Null, ScalarType::Float, _) => BoundValue::NullFloat,
560            (UpdateValue::Null, ScalarType::Decimal, _) => BoundValue::NullDecimal,
561            (UpdateValue::Null, ScalarType::Bytes, _) => BoundValue::NullBytes,
562            (UpdateValue::Null, ScalarType::DateTime, _) => BoundValue::NullDateTime,
563            (UpdateValue::Null, ScalarType::String, true) => BoundValue::NullUuid,
564            (value, _, _) => value.into(),
565        };
566
567        self.bindings.push(binding);
568        Ok(format!("${}", self.bindings.len()))
569    }
570
571    fn push_query_binding(
572        &mut self,
573        value: QueryVariableValue,
574        _scalar: ScalarType,
575    ) -> Result<String, sqlx::Error> {
576        self.bindings.push(value.into());
577        Ok(format!("${}", self.bindings.len()))
578    }
579}
580
581impl<'a> FilterBuilder<'a> for UpdateSqlBuilder<'a> {
582    fn schema(&self) -> &'a Schema {
583        self.schema
584    }
585
586    fn variables(&self) -> &'a QueryVariables {
587        self.variables
588    }
589
590    fn push_filter_binding(
591        &mut self,
592        value: QueryVariableValue,
593        scalar: ScalarType,
594    ) -> Result<String, sqlx::Error> {
595        self.push_query_binding(value, scalar)
596    }
597
598    fn next_filter_alias(&mut self) -> String {
599        let alias = format!("t{}", self.next_alias);
600        self.next_alias += 1;
601        alias
602    }
603
604    fn operation_name(&self) -> &'static str {
605        "update"
606    }
607}
608
609#[derive(Clone, Debug, PartialEq)]
610enum BoundValue {
611    Null,
612    NullString,
613    NullBool,
614    NullFloat,
615    NullDecimal,
616    NullBytes,
617    NullDateTime,
618    NullUuid,
619    Int(i64),
620    String(String),
621    Bool(bool),
622    Float(f64),
623    Decimal(Decimal),
624    Bytes(Vec<u8>),
625    DateTime(chrono::DateTime<chrono::Utc>),
626    Uuid(Uuid),
627    List(Vec<QueryVariableValue>),
628}
629
630impl From<UpdateValue> for BoundValue {
631    fn from(value: UpdateValue) -> Self {
632        match value {
633            UpdateValue::Null => Self::Null,
634            UpdateValue::Int(value) => Self::Int(value),
635            UpdateValue::String(value) => Self::String(value),
636            UpdateValue::Bool(value) => Self::Bool(value),
637            UpdateValue::Float(value) => Self::Float(value),
638            UpdateValue::Decimal(value) => Self::Decimal(value),
639            UpdateValue::Bytes(value) => Self::Bytes(value),
640            UpdateValue::DateTime(value) => Self::DateTime(value),
641            UpdateValue::Uuid(value) => Self::Uuid(value),
642        }
643    }
644}
645
646impl From<QueryVariableValue> for BoundValue {
647    fn from(value: QueryVariableValue) -> Self {
648        match value {
649            QueryVariableValue::Null => Self::Null,
650            QueryVariableValue::Int(value) => Self::Int(value),
651            QueryVariableValue::String(value) => Self::String(value),
652            QueryVariableValue::Bool(value) => Self::Bool(value),
653            QueryVariableValue::Float(value) => Self::Float(value),
654            QueryVariableValue::Decimal(value) => Self::Decimal(value),
655            QueryVariableValue::Bytes(value) => Self::Bytes(value),
656            QueryVariableValue::DateTime(value) => Self::DateTime(value),
657            QueryVariableValue::Uuid(value) => Self::Uuid(value),
658            QueryVariableValue::List(values) => Self::List(values),
659        }
660    }
661}
662
663fn update_value_matches_field(value: &UpdateValue, field: &Field) -> bool {
664    let FieldType::Scalar(scalar) = field.ty() else {
665        return false;
666    };
667
668    match value {
669        UpdateValue::Null => scalar.optional(),
670        UpdateValue::Int(_) => {
671            matches!(scalar.scalar(), ScalarType::Int | ScalarType::BigInt)
672        }
673        UpdateValue::String(_) => scalar.scalar() == ScalarType::String && !field.has_db_uuid(),
674        UpdateValue::Bool(_) => scalar.scalar() == ScalarType::Boolean,
675        UpdateValue::Float(_) => scalar.scalar() == ScalarType::Float,
676        UpdateValue::Decimal(_) => scalar.scalar() == ScalarType::Decimal,
677        UpdateValue::Bytes(_) => scalar.scalar() == ScalarType::Bytes,
678        UpdateValue::DateTime(_) => scalar.scalar() == ScalarType::DateTime,
679        UpdateValue::Uuid(_) => scalar.scalar() == ScalarType::String && field.has_db_uuid(),
680    }
681}
682
683fn bind_update<'q>(
684    mut query: SqlxQuery<'q, Postgres, PgArguments>,
685    bindings: &'q [BoundValue],
686) -> SqlxQuery<'q, Postgres, PgArguments> {
687    for binding in bindings {
688        query = match binding {
689            BoundValue::Null => query.bind(Option::<i64>::None),
690            BoundValue::NullString => query.bind(Option::<String>::None),
691            BoundValue::NullBool => query.bind(Option::<bool>::None),
692            BoundValue::NullFloat => query.bind(Option::<f64>::None),
693            BoundValue::NullDecimal => query.bind(Option::<Decimal>::None),
694            BoundValue::NullBytes => query.bind(Option::<Vec<u8>>::None),
695            BoundValue::NullDateTime => query.bind(Option::<chrono::DateTime<chrono::Utc>>::None),
696            BoundValue::NullUuid => query.bind(Option::<Uuid>::None),
697            BoundValue::Int(value) => query.bind(*value),
698            BoundValue::String(value) => query.bind(value),
699            BoundValue::Bool(value) => query.bind(*value),
700            BoundValue::Float(value) => query.bind(*value),
701            BoundValue::Decimal(value) => query.bind(*value),
702            BoundValue::Bytes(value) => query.bind(value),
703            BoundValue::DateTime(value) => query.bind(*value),
704            BoundValue::Uuid(value) => query.bind(*value),
705            BoundValue::List(values) => {
706                let first = values
707                    .first()
708                    .expect("list-valued query variables must not be empty when bound");
709
710                match first {
711                    QueryVariableValue::Null => {
712                        unreachable!("list-valued query variables must not contain null items")
713                    }
714                    QueryVariableValue::Int(_) => query.bind(
715                        values
716                            .iter()
717                            .map(|value| match value {
718                                QueryVariableValue::Int(value) => *value,
719                                _ => unreachable!("list-valued query variables must be homogenous"),
720                            })
721                            .collect::<Vec<_>>(),
722                    ),
723                    QueryVariableValue::String(_) => query.bind(
724                        values
725                            .iter()
726                            .map(|value| match value {
727                                QueryVariableValue::String(value) => value.clone(),
728                                _ => unreachable!("list-valued query variables must be homogenous"),
729                            })
730                            .collect::<Vec<_>>(),
731                    ),
732                    QueryVariableValue::Bool(_) => query.bind(
733                        values
734                            .iter()
735                            .map(|value| match value {
736                                QueryVariableValue::Bool(value) => *value,
737                                _ => unreachable!("list-valued query variables must be homogenous"),
738                            })
739                            .collect::<Vec<_>>(),
740                    ),
741                    QueryVariableValue::Float(_) => query.bind(
742                        values
743                            .iter()
744                            .map(|value| match value {
745                                QueryVariableValue::Float(value) => *value,
746                                _ => unreachable!("list-valued query variables must be homogenous"),
747                            })
748                            .collect::<Vec<_>>(),
749                    ),
750                    QueryVariableValue::Decimal(_) => query.bind(
751                        values
752                            .iter()
753                            .map(|value| match value {
754                                QueryVariableValue::Decimal(value) => *value,
755                                _ => unreachable!("list-valued query variables must be homogenous"),
756                            })
757                            .collect::<Vec<_>>(),
758                    ),
759                    QueryVariableValue::Bytes(_) => query.bind(
760                        values
761                            .iter()
762                            .map(|value| match value {
763                                QueryVariableValue::Bytes(value) => value.clone(),
764                                _ => unreachable!("list-valued query variables must be homogenous"),
765                            })
766                            .collect::<Vec<_>>(),
767                    ),
768                    QueryVariableValue::DateTime(_) => query.bind(
769                        values
770                            .iter()
771                            .map(|value| match value {
772                                QueryVariableValue::DateTime(value) => *value,
773                                _ => unreachable!("list-valued query variables must be homogenous"),
774                            })
775                            .collect::<Vec<_>>(),
776                    ),
777                    QueryVariableValue::Uuid(_) => query.bind(
778                        values
779                            .iter()
780                            .map(|value| match value {
781                                QueryVariableValue::Uuid(value) => *value,
782                                _ => unreachable!("list-valued query variables must be homogenous"),
783                            })
784                            .collect::<Vec<_>>(),
785                    ),
786                    QueryVariableValue::List(_) => {
787                        unreachable!("list-valued query variables must not contain nested lists")
788                    }
789                }
790            }
791        };
792    }
793
794    query
795}