1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3
4use rust_decimal::Decimal;
5use sqlx::postgres::{PgArguments, PgRow};
6use sqlx::{Postgres, query::Query as SqlxQuery};
7use uuid::Uuid;
8
9use crate::PgExecutor;
10use crate::query::{
11 BoxFuture, SchemaAccess, StringValueType, alias_name, quoted_ident, schema_error, select_expr,
12};
13use crate::schema::{
14 Attribute, DefaultFunction, Field, FieldType, Model, Resolution, ScalarType, Schema,
15};
16
17pub trait InsertSpec: Send + Sync {
19 type Output: Send + 'static;
20
21 #[doc(hidden)]
22 fn execute<'a>(
23 &'a self,
24 executor: &'a dyn PgExecutor,
25 ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>>;
26}
27
28pub trait InsertModel: Sized + Send + 'static {
30 type Schema: SchemaAccess;
31 type Values: InsertValueSet;
32
33 fn model_name() -> &'static str;
35
36 fn returning_fields() -> &'static [&'static str];
38
39 fn from_row(row: &PgRow, prefix: &str) -> Result<Self, sqlx::Error>;
41}
42
43pub trait InsertValueSet: Send + 'static {
45 fn into_insert_values(self) -> InsertValues;
46}
47
48impl InsertValueSet for InsertValues {
49 fn into_insert_values(self) -> InsertValues {
50 self
51 }
52}
53
54impl InsertValueSet for () {
55 fn into_insert_values(self) -> InsertValues {
56 InsertValues::new()
57 }
58}
59
60pub trait InsertScalar: Send {
61 fn into_insert_value(self) -> InsertValue;
62}
63
64#[derive(Clone, Debug, Default, PartialEq)]
65pub struct InsertValues {
66 values: Vec<InsertFieldValue>,
67 value_indices: HashMap<String, usize>,
68}
69
70impl InsertValues {
71 pub fn new() -> Self {
72 Self {
73 values: Vec::new(),
74 value_indices: HashMap::new(),
75 }
76 }
77
78 pub fn from_values(values: Vec<(impl Into<String>, InsertValue)>) -> Self {
79 let mut insert_values = Self::new();
80
81 for (name, value) in values {
82 insert_values
83 .push(name, value)
84 .expect("insert field names must be unique");
85 }
86
87 insert_values
88 }
89
90 pub fn push(
91 &mut self,
92 name: impl Into<String>,
93 value: InsertValue,
94 ) -> Result<usize, sqlx::Error> {
95 let name = name.into();
96
97 if self.value_indices.contains_key(&name) {
98 return Err(schema_error(format!("duplicate insert field `{name}`")));
99 }
100
101 let index = self.values.len();
102 self.values.push(InsertFieldValue {
103 name: name.clone(),
104 value,
105 });
106 self.value_indices.insert(name, index);
107 Ok(index)
108 }
109
110 pub fn get(&self, name: &str) -> Option<&InsertValue> {
111 self.value_indices
112 .get(name)
113 .and_then(|index| self.values.get(*index))
114 .map(|field| &field.value)
115 }
116
117 pub fn iter(&self) -> impl Iterator<Item = &InsertFieldValue> {
118 self.values.iter()
119 }
120
121 pub fn len(&self) -> usize {
122 self.values.len()
123 }
124
125 pub fn is_empty(&self) -> bool {
126 self.values.is_empty()
127 }
128}
129
130#[derive(Clone, Debug, PartialEq)]
131pub struct InsertFieldValue {
132 pub name: String,
133 pub value: InsertValue,
134}
135
136#[derive(Clone, Debug, PartialEq)]
137pub enum InsertValue {
138 Null,
139 Int(i64),
140 String(String),
141 Bool(bool),
142 Float(f64),
143 Decimal(Decimal),
144 Bytes(Vec<u8>),
145 DateTime(chrono::DateTime<chrono::Utc>),
146 Uuid(Uuid),
147}
148
149impl From<i64> for InsertValue {
150 fn from(value: i64) -> Self {
151 Self::Int(value)
152 }
153}
154
155impl From<String> for InsertValue {
156 fn from(value: String) -> Self {
157 Self::String(value)
158 }
159}
160
161impl From<&str> for InsertValue {
162 fn from(value: &str) -> Self {
163 Self::String(value.to_owned())
164 }
165}
166
167impl From<bool> for InsertValue {
168 fn from(value: bool) -> Self {
169 Self::Bool(value)
170 }
171}
172
173impl From<f64> for InsertValue {
174 fn from(value: f64) -> Self {
175 Self::Float(value)
176 }
177}
178
179impl From<Decimal> for InsertValue {
180 fn from(value: Decimal) -> Self {
181 Self::Decimal(value)
182 }
183}
184
185impl From<Vec<u8>> for InsertValue {
186 fn from(value: Vec<u8>) -> Self {
187 Self::Bytes(value)
188 }
189}
190
191impl From<&[u8]> for InsertValue {
192 fn from(value: &[u8]) -> Self {
193 Self::Bytes(value.to_vec())
194 }
195}
196
197impl From<chrono::DateTime<chrono::Utc>> for InsertValue {
198 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
199 Self::DateTime(value)
200 }
201}
202
203impl From<Uuid> for InsertValue {
204 fn from(value: Uuid) -> Self {
205 Self::Uuid(value)
206 }
207}
208
209impl<T> From<Option<T>> for InsertValue
210where
211 T: Into<InsertValue>,
212{
213 fn from(value: Option<T>) -> Self {
214 match value {
215 Some(value) => value.into(),
216 None => Self::Null,
217 }
218 }
219}
220
221impl InsertScalar for i64 {
222 fn into_insert_value(self) -> InsertValue {
223 self.into()
224 }
225}
226
227impl InsertScalar for &str {
228 fn into_insert_value(self) -> InsertValue {
229 self.into()
230 }
231}
232
233impl InsertScalar for bool {
234 fn into_insert_value(self) -> InsertValue {
235 self.into()
236 }
237}
238
239impl InsertScalar for f64 {
240 fn into_insert_value(self) -> InsertValue {
241 self.into()
242 }
243}
244
245impl InsertScalar for Decimal {
246 fn into_insert_value(self) -> InsertValue {
247 self.into()
248 }
249}
250
251impl InsertScalar for Vec<u8> {
252 fn into_insert_value(self) -> InsertValue {
253 self.into()
254 }
255}
256
257impl InsertScalar for &[u8] {
258 fn into_insert_value(self) -> InsertValue {
259 self.into()
260 }
261}
262
263impl InsertScalar for chrono::DateTime<chrono::Utc> {
264 fn into_insert_value(self) -> InsertValue {
265 self.into()
266 }
267}
268
269impl InsertScalar for Uuid {
270 fn into_insert_value(self) -> InsertValue {
271 self.into()
272 }
273}
274
275impl<T> InsertScalar for T
276where
277 T: StringValueType,
278{
279 fn into_insert_value(self) -> InsertValue {
280 InsertValue::String(self.into_db_string())
281 }
282}
283
284impl<T> InsertScalar for Option<T>
285where
286 T: InsertScalar,
287{
288 fn into_insert_value(self) -> InsertValue {
289 match self {
290 Some(value) => value.into_insert_value(),
291 None => InsertValue::Null,
292 }
293 }
294}
295
296#[derive(Clone, Debug)]
298pub struct Insert<S, T> {
299 values: InsertValues,
300 _marker: PhantomData<(S, T)>,
301}
302
303impl<S, T> Insert<S, T>
304where
305 T: InsertModel<Schema = S>,
306{
307 pub fn new(values: T::Values) -> Self {
308 Self {
309 values: values.into_insert_values(),
310 _marker: PhantomData,
311 }
312 }
313
314 pub fn with_values(values: InsertValues) -> Self {
315 Self {
316 values,
317 _marker: PhantomData,
318 }
319 }
320
321 pub fn values(&self) -> &InsertValues {
322 &self.values
323 }
324}
325
326impl<S, T> Insert<S, T>
327where
328 S: SchemaAccess,
329 T: InsertModel<Schema = S>,
330{
331 pub fn to_sql(&self) -> Result<String, sqlx::Error> {
332 let (sql, _) = build_insert_sql(
333 S::schema(),
334 T::model_name(),
335 &self.values,
336 T::returning_fields(),
337 )?;
338 Ok(sql)
339 }
340}
341
342impl<S, T> InsertSpec for Insert<S, T>
343where
344 S: SchemaAccess,
345 T: InsertModel<Schema = S> + Sync,
346{
347 type Output = T;
348
349 fn execute<'a>(
350 &'a self,
351 executor: &'a dyn PgExecutor,
352 ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
353 Box::pin(async move {
354 let (sql, bindings) = build_insert_sql(
355 S::schema(),
356 T::model_name(),
357 &self.values,
358 T::returning_fields(),
359 )?;
360 let row = executor
361 .fetch_one(bind_insert(sqlx::query(&sql), &bindings))
362 .await?;
363
364 T::from_row(&row, T::model_name())
365 })
366 }
367}
368
369fn build_insert_sql(
370 schema: &Schema,
371 model_name: &str,
372 values: &InsertValues,
373 returning_fields: &[&'static str],
374) -> Result<(String, Vec<BoundInsertValue>), sqlx::Error> {
375 let model = schema_model(schema, model_name)?;
376
377 validate_insert_values(model, values)?;
378 validate_returning_fields(model, returning_fields)?;
379
380 let ordered_values = ordered_insert_values(model, values);
381 let returning_clause = build_returning_clause(model, returning_fields, model_name)?;
382
383 let sql = if ordered_values.is_empty() {
384 format!(
385 "INSERT INTO {} DEFAULT VALUES RETURNING {}",
386 quoted_ident(model.name()),
387 returning_clause.join(", "),
388 )
389 } else {
390 let columns = ordered_values
391 .iter()
392 .map(|(field, _)| quoted_ident(field.name()))
393 .collect::<Vec<_>>();
394 let placeholders = ordered_values
395 .iter()
396 .enumerate()
397 .map(|(index, (field, value))| {
398 let placeholder = format!("${}", index + 1);
399 match (field.ty(), value) {
400 (FieldType::Scalar(scalar), InsertValue::Null)
401 if scalar.scalar() == ScalarType::String && !field.has_db_uuid() =>
402 {
403 format!("{placeholder}::text")
404 }
405 (FieldType::Scalar(scalar), InsertValue::Null)
406 if scalar.scalar() == ScalarType::Boolean =>
407 {
408 format!("{placeholder}::boolean")
409 }
410 (FieldType::Scalar(scalar), InsertValue::Null)
411 if scalar.scalar() == ScalarType::Float =>
412 {
413 format!("{placeholder}::double precision")
414 }
415 (FieldType::Scalar(scalar), InsertValue::Null)
416 if scalar.scalar() == ScalarType::Decimal =>
417 {
418 format!("{placeholder}::numeric")
419 }
420 (FieldType::Scalar(scalar), InsertValue::Null)
421 if scalar.scalar() == ScalarType::Bytes =>
422 {
423 format!("{placeholder}::bytea")
424 }
425 (FieldType::Scalar(scalar), InsertValue::Null)
426 if scalar.scalar() == ScalarType::DateTime =>
427 {
428 format!("{placeholder}::timestamp")
429 }
430 (FieldType::Scalar(_), InsertValue::Null) if field.has_db_uuid() => {
431 format!("{placeholder}::uuid")
432 }
433 _ => placeholder,
434 }
435 })
436 .collect::<Vec<_>>();
437
438 format!(
439 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
440 quoted_ident(model.name()),
441 columns.join(", "),
442 placeholders.join(", "),
443 returning_clause.join(", "),
444 )
445 };
446
447 let bindings = ordered_values
448 .into_iter()
449 .map(|(field, value)| match (field.ty(), value.clone()) {
450 (FieldType::Scalar(scalar), InsertValue::Null)
451 if scalar.scalar() == ScalarType::String && !field.has_db_uuid() =>
452 {
453 BoundInsertValue::NullString
454 }
455 (FieldType::Scalar(scalar), InsertValue::Null)
456 if scalar.scalar() == ScalarType::Boolean =>
457 {
458 BoundInsertValue::NullBool
459 }
460 (FieldType::Scalar(scalar), InsertValue::Null)
461 if scalar.scalar() == ScalarType::Float =>
462 {
463 BoundInsertValue::NullFloat
464 }
465 (FieldType::Scalar(scalar), InsertValue::Null)
466 if scalar.scalar() == ScalarType::Decimal =>
467 {
468 BoundInsertValue::NullDecimal
469 }
470 (FieldType::Scalar(scalar), InsertValue::Null)
471 if scalar.scalar() == ScalarType::Bytes =>
472 {
473 BoundInsertValue::NullBytes
474 }
475 (FieldType::Scalar(scalar), InsertValue::Null)
476 if scalar.scalar() == ScalarType::DateTime =>
477 {
478 BoundInsertValue::NullDateTime
479 }
480 (FieldType::Scalar(_), InsertValue::Null) if field.has_db_uuid() => {
481 BoundInsertValue::NullUuid
482 }
483 (_, value) => value.into(),
484 })
485 .collect();
486
487 Ok((sql, bindings))
488}
489
490#[derive(Clone, Debug, PartialEq)]
491enum BoundInsertValue {
492 Null,
493 NullString,
494 NullBool,
495 NullFloat,
496 NullDecimal,
497 NullBytes,
498 NullDateTime,
499 NullUuid,
500 Int(i64),
501 String(String),
502 Bool(bool),
503 Float(f64),
504 Decimal(Decimal),
505 Bytes(Vec<u8>),
506 DateTime(chrono::DateTime<chrono::Utc>),
507 Uuid(Uuid),
508}
509
510impl From<InsertValue> for BoundInsertValue {
511 fn from(value: InsertValue) -> Self {
512 match value {
513 InsertValue::Null => Self::Null,
514 InsertValue::Int(value) => Self::Int(value),
515 InsertValue::String(value) => Self::String(value),
516 InsertValue::Bool(value) => Self::Bool(value),
517 InsertValue::Float(value) => Self::Float(value),
518 InsertValue::Decimal(value) => Self::Decimal(value),
519 InsertValue::Bytes(value) => Self::Bytes(value),
520 InsertValue::DateTime(value) => Self::DateTime(value),
521 InsertValue::Uuid(value) => Self::Uuid(value),
522 }
523 }
524}
525
526fn validate_insert_values(model: &Model, values: &InsertValues) -> Result<(), sqlx::Error> {
527 for provided in values.iter() {
528 let field = model.field_named(&provided.name).ok_or_else(|| {
529 schema_error(format!(
530 "unknown field `{}` in insert for model `{}`",
531 provided.name,
532 model.name()
533 ))
534 })?;
535
536 if field.kind().is_relation() {
537 return Err(schema_error(format!(
538 "relation field `{}` cannot be written in insert for model `{}`",
539 field.name(),
540 model.name()
541 )));
542 }
543
544 if !insert_value_matches_field(&provided.value, field) {
545 return Err(schema_error(format!(
546 "insert value for field `{}` is incompatible with schema type `{}` on model `{}`",
547 field.name(),
548 field.ty().name(),
549 model.name()
550 )));
551 }
552 }
553
554 for field in model.fields() {
555 if field.kind().is_relation() {
556 continue;
557 }
558
559 if values.get(field.name()).is_none() && !field_can_be_omitted(field) {
560 return Err(schema_error(format!(
561 "missing required scalar field `{}` in insert for model `{}`",
562 field.name(),
563 model.name()
564 )));
565 }
566 }
567
568 Ok(())
569}
570
571fn validate_returning_fields(
572 model: &Model,
573 returning_fields: &[&'static str],
574) -> Result<(), sqlx::Error> {
575 if returning_fields.is_empty() {
576 return Err(schema_error(format!(
577 "insert on model `{}` must return at least one scalar field",
578 model.name()
579 )));
580 }
581
582 let mut seen = HashSet::new();
583
584 for field_name in returning_fields {
585 if !seen.insert(*field_name) {
586 return Err(schema_error(format!(
587 "duplicate returning field `{field_name}` in insert for model `{}`",
588 model.name()
589 )));
590 }
591
592 let field = model.field_named(field_name).ok_or_else(|| {
593 schema_error(format!(
594 "unknown returning field `{field_name}` in insert for model `{}`",
595 model.name()
596 ))
597 })?;
598
599 if field.kind().is_relation() {
600 return Err(schema_error(format!(
601 "relation field `{field_name}` cannot be returned from scalar insert for model `{}`",
602 model.name()
603 )));
604 }
605 }
606
607 Ok(())
608}
609
610fn ordered_insert_values<'a>(
611 model: &'a Model,
612 values: &'a InsertValues,
613) -> Vec<(&'a Field, &'a InsertValue)> {
614 let mut ordered = Vec::new();
615
616 for field in model.fields() {
617 if field.kind().is_relation() {
618 continue;
619 }
620
621 if let Some(value) = values.get(field.name()) {
622 ordered.push((field, value));
623 }
624 }
625
626 ordered
627}
628
629fn build_returning_clause(
630 model: &Model,
631 returning_fields: &[&'static str],
632 prefix: &str,
633) -> Result<Vec<String>, sqlx::Error> {
634 let mut selections = Vec::with_capacity(returning_fields.len());
635
636 for field_name in returning_fields {
637 let field = model.field_named(field_name).ok_or_else(|| {
638 schema_error(format!(
639 "unknown returning field `{field_name}` in insert for model `{}`",
640 model.name()
641 ))
642 })?;
643
644 let scalar = scalar_field_type(field).ok_or_else(|| {
645 schema_error(format!(
646 "relation field `{field_name}` cannot be returned from scalar insert for model `{}`",
647 model.name()
648 ))
649 })?;
650
651 let alias = alias_name(prefix, field_name);
652 selections.push(select_expr(model.name(), field_name, scalar, &alias));
653 }
654
655 Ok(selections)
656}
657
658fn field_can_be_omitted(field: &Field) -> bool {
659 field.ty().is_optional() || field_has_supported_default(field)
660}
661
662fn field_has_supported_default(field: &Field) -> bool {
663 field.attributes().iter().any(|attribute| {
664 matches!(
665 attribute,
666 Attribute::Default(default)
667 if matches!(
668 default.function(),
669 DefaultFunction::Autoincrement | DefaultFunction::Now
670 )
671 )
672 })
673}
674
675fn scalar_field_type(field: &Field) -> Option<ScalarType> {
676 match field.ty() {
677 FieldType::Scalar(scalar) => Some(scalar.scalar()),
678 FieldType::Relation { .. } => None,
679 }
680}
681
682fn insert_value_matches_field(value: &InsertValue, field: &Field) -> bool {
683 let FieldType::Scalar(scalar) = field.ty() else {
684 return false;
685 };
686
687 match value {
688 InsertValue::Null => scalar.optional(),
689 InsertValue::Int(_) => {
690 matches!(scalar.scalar(), ScalarType::Int | ScalarType::BigInt)
691 }
692 InsertValue::String(_) => scalar.scalar() == ScalarType::String && !field.has_db_uuid(),
693 InsertValue::Bool(_) => scalar.scalar() == ScalarType::Boolean,
694 InsertValue::Float(_) => scalar.scalar() == ScalarType::Float,
695 InsertValue::Decimal(_) => scalar.scalar() == ScalarType::Decimal,
696 InsertValue::Bytes(_) => scalar.scalar() == ScalarType::Bytes,
697 InsertValue::DateTime(_) => scalar.scalar() == ScalarType::DateTime,
698 InsertValue::Uuid(_) => scalar.scalar() == ScalarType::String && field.has_db_uuid(),
699 }
700}
701
702fn schema_model<'a>(schema: &'a Schema, requested: &str) -> Result<&'a Model, sqlx::Error> {
703 match schema.resolve_model(requested) {
704 Resolution::Found(model) => Ok(model),
705 Resolution::NotFound => Err(schema_error(format!(
706 "unknown model `{requested}` in insert"
707 ))),
708 Resolution::Ambiguous(models) => {
709 let candidates = models
710 .into_iter()
711 .map(|model| format!("`{}`", model.name()))
712 .collect::<Vec<_>>()
713 .join(", ");
714
715 Err(schema_error(format!(
716 "ambiguous model `{requested}` in insert; matches {candidates}"
717 )))
718 }
719 }
720}
721
722fn bind_insert<'q>(
723 mut query: SqlxQuery<'q, Postgres, PgArguments>,
724 bindings: &'q [BoundInsertValue],
725) -> SqlxQuery<'q, Postgres, PgArguments> {
726 for binding in bindings {
727 query = match binding {
728 BoundInsertValue::Null => query.bind(Option::<i64>::None),
729 BoundInsertValue::NullString => query.bind(Option::<String>::None),
730 BoundInsertValue::NullBool => query.bind(Option::<bool>::None),
731 BoundInsertValue::NullFloat => query.bind(Option::<f64>::None),
732 BoundInsertValue::NullDecimal => query.bind(Option::<Decimal>::None),
733 BoundInsertValue::NullBytes => query.bind(Option::<Vec<u8>>::None),
734 BoundInsertValue::NullDateTime => {
735 query.bind(Option::<chrono::DateTime<chrono::Utc>>::None)
736 }
737 BoundInsertValue::NullUuid => query.bind(Option::<Uuid>::None),
738 BoundInsertValue::Int(value) => query.bind(*value),
739 BoundInsertValue::String(value) => query.bind(value),
740 BoundInsertValue::Bool(value) => query.bind(*value),
741 BoundInsertValue::Float(value) => query.bind(*value),
742 BoundInsertValue::Decimal(value) => query.bind(*value),
743 BoundInsertValue::Bytes(value) => query.bind(value),
744 BoundInsertValue::DateTime(value) => query.bind(*value),
745 BoundInsertValue::Uuid(value) => query.bind(*value),
746 };
747 }
748
749 query
750}