// vitrail_pg_core/insert.rs

use std::collections::{HashMap, HashSet};
use std::fmt;
use std::marker::PhantomData;

use rust_decimal::Decimal;
use sqlx::postgres::{PgArguments, PgRow};
use sqlx::{Postgres, query::Query as SqlxQuery};
use uuid::Uuid;

use crate::PgExecutor;
use crate::query::{
    BoxFuture, SchemaAccess, StringValueType, alias_name, quoted_ident, schema_error, select_expr,
};
use crate::schema::{
    Attribute, DefaultFunction, Field, FieldType, Model, Resolution, ScalarType, Schema,
};

17/// Runtime contract implemented by executable insert values.
18pub trait InsertSpec: Send + Sync {
19    type Output: Send + 'static;
20
21    #[doc(hidden)]
22    fn execute<'a>(
23        &'a self,
24        executor: &'a dyn PgExecutor,
25    ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>>;
26}
27
28/// Runtime contract implemented by insert result models.
29pub trait InsertModel: Sized + Send + 'static {
30    type Schema: SchemaAccess;
31    type Values: InsertValueSet;
32
33    /// Schema model name being inserted into.
34    fn model_name() -> &'static str;
35
36    /// Scalar fields returned by the insert statement.
37    fn returning_fields() -> &'static [&'static str];
38
39    /// Decodes the inserted row returned by `RETURNING`.
40    fn from_row(row: &PgRow, prefix: &str) -> Result<Self, sqlx::Error>;
41}
42
43/// Converts a user-provided input into executable insert values.
44pub trait InsertValueSet: Send + 'static {
45    fn into_insert_values(self) -> InsertValues;
46}
47
48impl InsertValueSet for InsertValues {
49    fn into_insert_values(self) -> InsertValues {
50        self
51    }
52}
53
54impl InsertValueSet for () {
55    fn into_insert_values(self) -> InsertValues {
56        InsertValues::new()
57    }
58}
59
60pub trait InsertScalar: Send {
61    fn into_insert_value(self) -> InsertValue;
62}
63
64#[derive(Clone, Debug, Default, PartialEq)]
65pub struct InsertValues {
66    values: Vec<InsertFieldValue>,
67    value_indices: HashMap<String, usize>,
68}
69
70impl InsertValues {
71    pub fn new() -> Self {
72        Self {
73            values: Vec::new(),
74            value_indices: HashMap::new(),
75        }
76    }
77
78    pub fn from_values(values: Vec<(impl Into<String>, InsertValue)>) -> Self {
79        let mut insert_values = Self::new();
80
81        for (name, value) in values {
82            insert_values
83                .push(name, value)
84                .expect("insert field names must be unique");
85        }
86
87        insert_values
88    }
89
90    pub fn push(
91        &mut self,
92        name: impl Into<String>,
93        value: InsertValue,
94    ) -> Result<usize, sqlx::Error> {
95        let name = name.into();
96
97        if self.value_indices.contains_key(&name) {
98            return Err(schema_error(format!("duplicate insert field `{name}`")));
99        }
100
101        let index = self.values.len();
102        self.values.push(InsertFieldValue {
103            name: name.clone(),
104            value,
105        });
106        self.value_indices.insert(name, index);
107        Ok(index)
108    }
109
110    pub fn get(&self, name: &str) -> Option<&InsertValue> {
111        self.value_indices
112            .get(name)
113            .and_then(|index| self.values.get(*index))
114            .map(|field| &field.value)
115    }
116
117    pub fn iter(&self) -> impl Iterator<Item = &InsertFieldValue> {
118        self.values.iter()
119    }
120
121    pub fn len(&self) -> usize {
122        self.values.len()
123    }
124
125    pub fn is_empty(&self) -> bool {
126        self.values.is_empty()
127    }
128}
129
130#[derive(Clone, Debug, PartialEq)]
131pub struct InsertFieldValue {
132    pub name: String,
133    pub value: InsertValue,
134}
135
136#[derive(Clone, Debug, PartialEq)]
137pub enum InsertValue {
138    Null,
139    Int(i64),
140    String(String),
141    Bool(bool),
142    Float(f64),
143    Decimal(Decimal),
144    Bytes(Vec<u8>),
145    DateTime(chrono::DateTime<chrono::Utc>),
146    Uuid(Uuid),
147}
148
149impl From<i64> for InsertValue {
150    fn from(value: i64) -> Self {
151        Self::Int(value)
152    }
153}
154
155impl From<String> for InsertValue {
156    fn from(value: String) -> Self {
157        Self::String(value)
158    }
159}
160
161impl From<&str> for InsertValue {
162    fn from(value: &str) -> Self {
163        Self::String(value.to_owned())
164    }
165}
166
167impl From<bool> for InsertValue {
168    fn from(value: bool) -> Self {
169        Self::Bool(value)
170    }
171}
172
173impl From<f64> for InsertValue {
174    fn from(value: f64) -> Self {
175        Self::Float(value)
176    }
177}
178
179impl From<Decimal> for InsertValue {
180    fn from(value: Decimal) -> Self {
181        Self::Decimal(value)
182    }
183}
184
185impl From<Vec<u8>> for InsertValue {
186    fn from(value: Vec<u8>) -> Self {
187        Self::Bytes(value)
188    }
189}
190
191impl From<&[u8]> for InsertValue {
192    fn from(value: &[u8]) -> Self {
193        Self::Bytes(value.to_vec())
194    }
195}
196
197impl From<chrono::DateTime<chrono::Utc>> for InsertValue {
198    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
199        Self::DateTime(value)
200    }
201}
202
203impl From<Uuid> for InsertValue {
204    fn from(value: Uuid) -> Self {
205        Self::Uuid(value)
206    }
207}
208
209impl<T> From<Option<T>> for InsertValue
210where
211    T: Into<InsertValue>,
212{
213    fn from(value: Option<T>) -> Self {
214        match value {
215            Some(value) => value.into(),
216            None => Self::Null,
217        }
218    }
219}
220
221impl InsertScalar for i64 {
222    fn into_insert_value(self) -> InsertValue {
223        self.into()
224    }
225}
226
227impl InsertScalar for &str {
228    fn into_insert_value(self) -> InsertValue {
229        self.into()
230    }
231}
232
233impl InsertScalar for bool {
234    fn into_insert_value(self) -> InsertValue {
235        self.into()
236    }
237}
238
239impl InsertScalar for f64 {
240    fn into_insert_value(self) -> InsertValue {
241        self.into()
242    }
243}
244
245impl InsertScalar for Decimal {
246    fn into_insert_value(self) -> InsertValue {
247        self.into()
248    }
249}
250
251impl InsertScalar for Vec<u8> {
252    fn into_insert_value(self) -> InsertValue {
253        self.into()
254    }
255}
256
257impl InsertScalar for &[u8] {
258    fn into_insert_value(self) -> InsertValue {
259        self.into()
260    }
261}
262
263impl InsertScalar for chrono::DateTime<chrono::Utc> {
264    fn into_insert_value(self) -> InsertValue {
265        self.into()
266    }
267}
268
269impl InsertScalar for Uuid {
270    fn into_insert_value(self) -> InsertValue {
271        self.into()
272    }
273}
274
275impl<T> InsertScalar for T
276where
277    T: StringValueType,
278{
279    fn into_insert_value(self) -> InsertValue {
280        InsertValue::String(self.into_db_string())
281    }
282}
283
284impl<T> InsertScalar for Option<T>
285where
286    T: InsertScalar,
287{
288    fn into_insert_value(self) -> InsertValue {
289        match self {
290            Some(value) => value.into_insert_value(),
291            None => InsertValue::Null,
292        }
293    }
294}
295
296/// Executable scalar insert for exactly one row.
297#[derive(Clone, Debug)]
298pub struct Insert<S, T> {
299    values: InsertValues,
300    _marker: PhantomData<(S, T)>,
301}
302
303impl<S, T> Insert<S, T>
304where
305    T: InsertModel<Schema = S>,
306{
307    pub fn new(values: T::Values) -> Self {
308        Self {
309            values: values.into_insert_values(),
310            _marker: PhantomData,
311        }
312    }
313
314    pub fn with_values(values: InsertValues) -> Self {
315        Self {
316            values,
317            _marker: PhantomData,
318        }
319    }
320
321    pub fn values(&self) -> &InsertValues {
322        &self.values
323    }
324}
325
326impl<S, T> Insert<S, T>
327where
328    S: SchemaAccess,
329    T: InsertModel<Schema = S>,
330{
331    pub fn to_sql(&self) -> Result<String, sqlx::Error> {
332        let (sql, _) = build_insert_sql(
333            S::schema(),
334            T::model_name(),
335            &self.values,
336            T::returning_fields(),
337        )?;
338        Ok(sql)
339    }
340}
341
342impl<S, T> InsertSpec for Insert<S, T>
343where
344    S: SchemaAccess,
345    T: InsertModel<Schema = S> + Sync,
346{
347    type Output = T;
348
349    fn execute<'a>(
350        &'a self,
351        executor: &'a dyn PgExecutor,
352    ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
353        Box::pin(async move {
354            let (sql, bindings) = build_insert_sql(
355                S::schema(),
356                T::model_name(),
357                &self.values,
358                T::returning_fields(),
359            )?;
360            let row = executor
361                .fetch_one(bind_insert(sqlx::query(&sql), &bindings))
362                .await?;
363
364            T::from_row(&row, T::model_name())
365        })
366    }
367}
368
369fn build_insert_sql(
370    schema: &Schema,
371    model_name: &str,
372    values: &InsertValues,
373    returning_fields: &[&'static str],
374) -> Result<(String, Vec<BoundInsertValue>), sqlx::Error> {
375    let model = schema_model(schema, model_name)?;
376
377    validate_insert_values(model, values)?;
378    validate_returning_fields(model, returning_fields)?;
379
380    let ordered_values = ordered_insert_values(model, values);
381    let returning_clause = build_returning_clause(model, returning_fields, model_name)?;
382
383    let sql = if ordered_values.is_empty() {
384        format!(
385            "INSERT INTO {} DEFAULT VALUES RETURNING {}",
386            quoted_ident(model.name()),
387            returning_clause.join(", "),
388        )
389    } else {
390        let columns = ordered_values
391            .iter()
392            .map(|(field, _)| quoted_ident(field.name()))
393            .collect::<Vec<_>>();
394        let placeholders = ordered_values
395            .iter()
396            .enumerate()
397            .map(|(index, (field, value))| {
398                let placeholder = format!("${}", index + 1);
399                match (field.ty(), value) {
400                    (FieldType::Scalar(scalar), InsertValue::Null)
401                        if scalar.scalar() == ScalarType::String && !field.has_db_uuid() =>
402                    {
403                        format!("{placeholder}::text")
404                    }
405                    (FieldType::Scalar(scalar), InsertValue::Null)
406                        if scalar.scalar() == ScalarType::Boolean =>
407                    {
408                        format!("{placeholder}::boolean")
409                    }
410                    (FieldType::Scalar(scalar), InsertValue::Null)
411                        if scalar.scalar() == ScalarType::Float =>
412                    {
413                        format!("{placeholder}::double precision")
414                    }
415                    (FieldType::Scalar(scalar), InsertValue::Null)
416                        if scalar.scalar() == ScalarType::Decimal =>
417                    {
418                        format!("{placeholder}::numeric")
419                    }
420                    (FieldType::Scalar(scalar), InsertValue::Null)
421                        if scalar.scalar() == ScalarType::Bytes =>
422                    {
423                        format!("{placeholder}::bytea")
424                    }
425                    (FieldType::Scalar(scalar), InsertValue::Null)
426                        if scalar.scalar() == ScalarType::DateTime =>
427                    {
428                        format!("{placeholder}::timestamp")
429                    }
430                    (FieldType::Scalar(_), InsertValue::Null) if field.has_db_uuid() => {
431                        format!("{placeholder}::uuid")
432                    }
433                    _ => placeholder,
434                }
435            })
436            .collect::<Vec<_>>();
437
438        format!(
439            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
440            quoted_ident(model.name()),
441            columns.join(", "),
442            placeholders.join(", "),
443            returning_clause.join(", "),
444        )
445    };
446
447    let bindings = ordered_values
448        .into_iter()
449        .map(|(field, value)| match (field.ty(), value.clone()) {
450            (FieldType::Scalar(scalar), InsertValue::Null)
451                if scalar.scalar() == ScalarType::String && !field.has_db_uuid() =>
452            {
453                BoundInsertValue::NullString
454            }
455            (FieldType::Scalar(scalar), InsertValue::Null)
456                if scalar.scalar() == ScalarType::Boolean =>
457            {
458                BoundInsertValue::NullBool
459            }
460            (FieldType::Scalar(scalar), InsertValue::Null)
461                if scalar.scalar() == ScalarType::Float =>
462            {
463                BoundInsertValue::NullFloat
464            }
465            (FieldType::Scalar(scalar), InsertValue::Null)
466                if scalar.scalar() == ScalarType::Decimal =>
467            {
468                BoundInsertValue::NullDecimal
469            }
470            (FieldType::Scalar(scalar), InsertValue::Null)
471                if scalar.scalar() == ScalarType::Bytes =>
472            {
473                BoundInsertValue::NullBytes
474            }
475            (FieldType::Scalar(scalar), InsertValue::Null)
476                if scalar.scalar() == ScalarType::DateTime =>
477            {
478                BoundInsertValue::NullDateTime
479            }
480            (FieldType::Scalar(_), InsertValue::Null) if field.has_db_uuid() => {
481                BoundInsertValue::NullUuid
482            }
483            (_, value) => value.into(),
484        })
485        .collect();
486
487    Ok((sql, bindings))
488}
489
490#[derive(Clone, Debug, PartialEq)]
491enum BoundInsertValue {
492    Null,
493    NullString,
494    NullBool,
495    NullFloat,
496    NullDecimal,
497    NullBytes,
498    NullDateTime,
499    NullUuid,
500    Int(i64),
501    String(String),
502    Bool(bool),
503    Float(f64),
504    Decimal(Decimal),
505    Bytes(Vec<u8>),
506    DateTime(chrono::DateTime<chrono::Utc>),
507    Uuid(Uuid),
508}
509
510impl From<InsertValue> for BoundInsertValue {
511    fn from(value: InsertValue) -> Self {
512        match value {
513            InsertValue::Null => Self::Null,
514            InsertValue::Int(value) => Self::Int(value),
515            InsertValue::String(value) => Self::String(value),
516            InsertValue::Bool(value) => Self::Bool(value),
517            InsertValue::Float(value) => Self::Float(value),
518            InsertValue::Decimal(value) => Self::Decimal(value),
519            InsertValue::Bytes(value) => Self::Bytes(value),
520            InsertValue::DateTime(value) => Self::DateTime(value),
521            InsertValue::Uuid(value) => Self::Uuid(value),
522        }
523    }
524}
525
526fn validate_insert_values(model: &Model, values: &InsertValues) -> Result<(), sqlx::Error> {
527    for provided in values.iter() {
528        let field = model.field_named(&provided.name).ok_or_else(|| {
529            schema_error(format!(
530                "unknown field `{}` in insert for model `{}`",
531                provided.name,
532                model.name()
533            ))
534        })?;
535
536        if field.kind().is_relation() {
537            return Err(schema_error(format!(
538                "relation field `{}` cannot be written in insert for model `{}`",
539                field.name(),
540                model.name()
541            )));
542        }
543
544        if !insert_value_matches_field(&provided.value, field) {
545            return Err(schema_error(format!(
546                "insert value for field `{}` is incompatible with schema type `{}` on model `{}`",
547                field.name(),
548                field.ty().name(),
549                model.name()
550            )));
551        }
552    }
553
554    for field in model.fields() {
555        if field.kind().is_relation() {
556            continue;
557        }
558
559        if values.get(field.name()).is_none() && !field_can_be_omitted(field) {
560            return Err(schema_error(format!(
561                "missing required scalar field `{}` in insert for model `{}`",
562                field.name(),
563                model.name()
564            )));
565        }
566    }
567
568    Ok(())
569}
570
571fn validate_returning_fields(
572    model: &Model,
573    returning_fields: &[&'static str],
574) -> Result<(), sqlx::Error> {
575    if returning_fields.is_empty() {
576        return Err(schema_error(format!(
577            "insert on model `{}` must return at least one scalar field",
578            model.name()
579        )));
580    }
581
582    let mut seen = HashSet::new();
583
584    for field_name in returning_fields {
585        if !seen.insert(*field_name) {
586            return Err(schema_error(format!(
587                "duplicate returning field `{field_name}` in insert for model `{}`",
588                model.name()
589            )));
590        }
591
592        let field = model.field_named(field_name).ok_or_else(|| {
593            schema_error(format!(
594                "unknown returning field `{field_name}` in insert for model `{}`",
595                model.name()
596            ))
597        })?;
598
599        if field.kind().is_relation() {
600            return Err(schema_error(format!(
601                "relation field `{field_name}` cannot be returned from scalar insert for model `{}`",
602                model.name()
603            )));
604        }
605    }
606
607    Ok(())
608}
609
610fn ordered_insert_values<'a>(
611    model: &'a Model,
612    values: &'a InsertValues,
613) -> Vec<(&'a Field, &'a InsertValue)> {
614    let mut ordered = Vec::new();
615
616    for field in model.fields() {
617        if field.kind().is_relation() {
618            continue;
619        }
620
621        if let Some(value) = values.get(field.name()) {
622            ordered.push((field, value));
623        }
624    }
625
626    ordered
627}
628
629fn build_returning_clause(
630    model: &Model,
631    returning_fields: &[&'static str],
632    prefix: &str,
633) -> Result<Vec<String>, sqlx::Error> {
634    let mut selections = Vec::with_capacity(returning_fields.len());
635
636    for field_name in returning_fields {
637        let field = model.field_named(field_name).ok_or_else(|| {
638            schema_error(format!(
639                "unknown returning field `{field_name}` in insert for model `{}`",
640                model.name()
641            ))
642        })?;
643
644        let scalar = scalar_field_type(field).ok_or_else(|| {
645            schema_error(format!(
646                "relation field `{field_name}` cannot be returned from scalar insert for model `{}`",
647                model.name()
648            ))
649        })?;
650
651        let alias = alias_name(prefix, field_name);
652        selections.push(select_expr(model.name(), field_name, scalar, &alias));
653    }
654
655    Ok(selections)
656}
657
658fn field_can_be_omitted(field: &Field) -> bool {
659    field.ty().is_optional() || field_has_supported_default(field)
660}
661
662fn field_has_supported_default(field: &Field) -> bool {
663    field.attributes().iter().any(|attribute| {
664        matches!(
665            attribute,
666            Attribute::Default(default)
667                if matches!(
668                    default.function(),
669                    DefaultFunction::Autoincrement | DefaultFunction::Now
670                )
671        )
672    })
673}
674
675fn scalar_field_type(field: &Field) -> Option<ScalarType> {
676    match field.ty() {
677        FieldType::Scalar(scalar) => Some(scalar.scalar()),
678        FieldType::Relation { .. } => None,
679    }
680}
681
682fn insert_value_matches_field(value: &InsertValue, field: &Field) -> bool {
683    let FieldType::Scalar(scalar) = field.ty() else {
684        return false;
685    };
686
687    match value {
688        InsertValue::Null => scalar.optional(),
689        InsertValue::Int(_) => {
690            matches!(scalar.scalar(), ScalarType::Int | ScalarType::BigInt)
691        }
692        InsertValue::String(_) => scalar.scalar() == ScalarType::String && !field.has_db_uuid(),
693        InsertValue::Bool(_) => scalar.scalar() == ScalarType::Boolean,
694        InsertValue::Float(_) => scalar.scalar() == ScalarType::Float,
695        InsertValue::Decimal(_) => scalar.scalar() == ScalarType::Decimal,
696        InsertValue::Bytes(_) => scalar.scalar() == ScalarType::Bytes,
697        InsertValue::DateTime(_) => scalar.scalar() == ScalarType::DateTime,
698        InsertValue::Uuid(_) => scalar.scalar() == ScalarType::String && field.has_db_uuid(),
699    }
700}
701
702fn schema_model<'a>(schema: &'a Schema, requested: &str) -> Result<&'a Model, sqlx::Error> {
703    match schema.resolve_model(requested) {
704        Resolution::Found(model) => Ok(model),
705        Resolution::NotFound => Err(schema_error(format!(
706            "unknown model `{requested}` in insert"
707        ))),
708        Resolution::Ambiguous(models) => {
709            let candidates = models
710                .into_iter()
711                .map(|model| format!("`{}`", model.name()))
712                .collect::<Vec<_>>()
713                .join(", ");
714
715            Err(schema_error(format!(
716                "ambiguous model `{requested}` in insert; matches {candidates}"
717            )))
718        }
719    }
720}
721
722fn bind_insert<'q>(
723    mut query: SqlxQuery<'q, Postgres, PgArguments>,
724    bindings: &'q [BoundInsertValue],
725) -> SqlxQuery<'q, Postgres, PgArguments> {
726    for binding in bindings {
727        query = match binding {
728            BoundInsertValue::Null => query.bind(Option::<i64>::None),
729            BoundInsertValue::NullString => query.bind(Option::<String>::None),
730            BoundInsertValue::NullBool => query.bind(Option::<bool>::None),
731            BoundInsertValue::NullFloat => query.bind(Option::<f64>::None),
732            BoundInsertValue::NullDecimal => query.bind(Option::<Decimal>::None),
733            BoundInsertValue::NullBytes => query.bind(Option::<Vec<u8>>::None),
734            BoundInsertValue::NullDateTime => {
735                query.bind(Option::<chrono::DateTime<chrono::Utc>>::None)
736            }
737            BoundInsertValue::NullUuid => query.bind(Option::<Uuid>::None),
738            BoundInsertValue::Int(value) => query.bind(*value),
739            BoundInsertValue::String(value) => query.bind(value),
740            BoundInsertValue::Bool(value) => query.bind(*value),
741            BoundInsertValue::Float(value) => query.bind(*value),
742            BoundInsertValue::Decimal(value) => query.bind(*value),
743            BoundInsertValue::Bytes(value) => query.bind(value),
744            BoundInsertValue::DateTime(value) => query.bind(*value),
745            BoundInsertValue::Uuid(value) => query.bind(*value),
746        };
747    }
748
749    query
750}