sqlmo/query/
insert.rs

1use crate::query::Expr;
2use crate::util::SqlExtension;
3use crate::{Dialect, Select, ToSql};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, PartialEq, Eq)]
7pub enum OnConflict {
8    Ignore,
9    Abort,
10    /// Only valid for Sqlite, because we
11    Replace,
12    /// Only valid for Postgres
13    DoUpdate {
14        conflict: Conflict,
15        updates: Vec<(String, Expr)>,
16    },
17    /// Only valid for Postgres
18    DoUpdateAllRows {
19        conflict: Conflict,
20        alternate_values: HashMap<String, Expr>,
21        ignore_columns: Vec<String>,
22    },
23}
24
25impl OnConflict {
26    pub fn do_update_all_rows(columns: &[&str]) -> Self {
27        OnConflict::DoUpdateAllRows {
28            conflict: Conflict::Columns(columns.iter().map(|c| c.to_string()).collect()),
29            alternate_values: HashMap::new(),
30            ignore_columns: Vec::new(),
31        }
32    }
33
34    pub fn do_update_on_pkey(pkey: &str) -> Self {
35        OnConflict::DoUpdateAllRows {
36            conflict: Conflict::Columns(vec![pkey.to_string()]),
37            alternate_values: HashMap::new(),
38            ignore_columns: Vec::new(),
39        }
40    }
41
42    pub fn alternate_value<V: Into<Expr>>(mut self, column: &str, value: V) -> Self {
43        match &mut self {
44            OnConflict::DoUpdateAllRows {
45                alternate_values, ..
46            } => {
47                alternate_values.insert(column.to_string(), value.into());
48            }
49            _ => panic!("alternate_value is only valid for DoUpdate"),
50        }
51        self
52    }
53}
54
/// Target of a Postgres `ON CONFLICT` clause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Conflict {
    /// A column list, rendered as a parenthesised, quoted sequence, e.g. `("id")`.
    Columns(Vec<String>),
    /// A named constraint, rendered as `ON CONSTRAINT <quoted name>`.
    ConstraintName(String),
    /// No explicit target; nothing is rendered.
    NoTarget,
}

impl Conflict {
    /// Builds a [`Conflict::Columns`] target from any iterable of column names.
    pub fn columns(t: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let names: Vec<String> = t.into_iter().map(Into::into).collect();
        Self::Columns(names)
    }

    /// Returns the column list when this target is [`Conflict::Columns`].
    pub fn as_columns(&self) -> Option<&Vec<String>> {
        if let Self::Columns(names) = self {
            Some(names)
        } else {
            None
        }
    }
}
74
75impl Default for OnConflict {
76    fn default() -> Self {
77        OnConflict::Abort
78    }
79}
80
81impl ToSql for Conflict {
82    fn write_sql(&self, buf: &mut String, _dialect: Dialect) {
83        match self {
84            Conflict::Columns(c) => {
85                buf.push('(');
86                buf.push_quoted_sequence(c, ", ");
87                buf.push(')');
88            }
89            Conflict::ConstraintName(name) => {
90                buf.push_str("ON CONSTRAINT ");
91                buf.push_quoted(name);
92            }
93            Conflict::NoTarget => {}
94        }
95    }
96}
97
98impl ToSql for Values {
99    fn write_sql(&self, buf: &mut String, dialect: Dialect) {
100        match self {
101            Values::Values(values) => {
102                let mut first_value = true;
103                for value in values {
104                    if !first_value {
105                        buf.push_str(", ");
106                    }
107                    let mut first = true;
108                    buf.push('(');
109                    for v in &value.0 {
110                        if !first {
111                            buf.push_str(", ");
112                        }
113                        buf.push_str(v);
114                        first = false;
115                    }
116                    buf.push(')');
117                    first_value = false;
118                }
119            }
120            Values::Select(select) => {
121                buf.push_sql(select, dialect);
122            }
123            Values::DefaultValues => {
124                buf.push_str("DEFAULT VALUES");
125            }
126        }
127    }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq)]
131pub struct Value(Vec<String>);
132
133impl Value {
134    pub fn with(values: &[&str]) -> Self {
135        Self(values.into_iter().map(|v| v.to_string()).collect())
136    }
137
138    pub fn new() -> Self {
139        Self(Vec::new())
140    }
141
142    pub fn column(mut self, value: &str) -> Self {
143        self.0.push(value.to_string());
144        self
145    }
146
147    pub fn placeholders(mut self, count: usize, dialect: Dialect) -> Self {
148        use Dialect::*;
149        for i in 1..(count + 1) {
150            match dialect {
151                Postgres => self.0.push(format!("${}", i)),
152                Mysql | Sqlite => self.0.push("?".to_string()),
153            }
154        }
155        self
156    }
157}
158
159impl From<Vec<String>> for Value {
160    fn from(values: Vec<String>) -> Self {
161        Self(values)
162    }
163}
164
165#[derive(Debug, Clone, PartialEq, Eq)]
166pub enum Values {
167    Values(Vec<Value>),
168    Select(Select),
169    DefaultValues,
170}
171
172impl From<&[&[&'static str]]> for Values {
173    fn from(values: &[&[&'static str]]) -> Self {
174        Self::Values(values.into_iter().map(|v| Value::with(v)).collect())
175    }
176}
177
178impl From<&[&'static str]> for Values {
179    fn from(values: &[&'static str]) -> Self {
180        Self::Values(vec![Value::with(values)])
181    }
182}
183
184impl Values {
185    pub fn new_value(value: Value) -> Self {
186        Self::Values(vec![value])
187    }
188
189    pub fn select(select: Select) -> Self {
190        Self::Select(select)
191    }
192
193    pub fn default_values() -> Self {
194        Self::DefaultValues
195    }
196
197    pub fn value(mut self, value: Value) -> Self {
198        match &mut self {
199            Self::Values(values) => values.push(value),
200            _ => panic!("Cannot add value to non-values"),
201        }
202        self
203    }
204}
205
206#[derive(Debug, Clone, PartialEq, Eq)]
207pub struct Insert {
208    pub schema: Option<String>,
209    pub table: String,
210    pub columns: Vec<String>,
211    pub values: Values,
212    pub on_conflict: OnConflict,
213    pub returning: Vec<String>,
214}
215
216impl Insert {
217    pub fn new(table: &str) -> Self {
218        Self {
219            schema: None,
220            table: table.to_string(),
221            columns: Vec::new(),
222            values: Values::DefaultValues,
223            on_conflict: OnConflict::default(),
224            returning: Vec::new(),
225        }
226    }
227
228    pub fn schema(mut self, schema: &str) -> Self {
229        self.schema = Some(schema.to_string());
230        self
231    }
232
233    pub fn column(mut self, column: &str) -> Self {
234        self.columns.push(column.to_string());
235        self
236    }
237
238    pub fn values(mut self, value: Values) -> Self {
239        self.values = value;
240        self
241    }
242
243    pub fn columns(mut self, columns: &[&str]) -> Self {
244        self.columns = columns.iter().map(|c| c.to_string()).collect();
245        self
246    }
247
248    pub fn placeholder_for_each_column(mut self, dialect: Dialect) -> Self {
249        self.values = Values::new_value(Value::new().placeholders(self.columns.len(), dialect));
250        self
251    }
252
253    #[deprecated(note = "Use .values(Values::from(...)) instead")]
254    pub fn one_value(mut self, values: &[&str]) -> Self {
255        self.values = Values::Values(vec![Value::with(values)]);
256        self
257    }
258
259    pub fn on_conflict(mut self, on_conflict: OnConflict) -> Self {
260        self.on_conflict = on_conflict;
261        self
262    }
263
264    pub fn returning(mut self, returning: &[&str]) -> Self {
265        self.returning = returning.iter().map(|r| r.to_string()).collect();
266        self
267    }
268}
269
270impl ToSql for Insert {
271    fn write_sql(&self, buf: &mut String, dialect: Dialect) {
272        use Dialect::*;
273        use OnConflict::*;
274        if dialect == Sqlite {
275            match self.on_conflict {
276                Ignore => buf.push_str("INSERT OR IGNORE INTO "),
277                Abort => buf.push_str("INSERT OR ABORT INTO "),
278                Replace => buf.push_str("INSERT OR REPLACE INTO "),
279                DoUpdateAllRows { .. } | DoUpdate { .. } => {
280                    panic!("Sqlite does not support ON CONFLICT DO UPDATE")
281                }
282            }
283        } else {
284            buf.push_str("INSERT INTO ");
285        }
286        buf.push_table_name(&self.schema, &self.table);
287        buf.push_str(" (");
288        buf.push_quoted_sequence(&self.columns, ", ");
289        buf.push_str(") VALUES ");
290        self.values.write_sql(buf, dialect);
291
292        if dialect == Postgres {
293            match &self.on_conflict {
294                Ignore => buf.push_str(" ON CONFLICT DO NOTHING"),
295                Abort => {}
296                Replace => panic!("Postgres does not support ON CONFLICT REPLACE"),
297                DoUpdate { conflict, updates } => {
298                    buf.push_str(" ON CONFLICT ");
299                    buf.push_sql(conflict, dialect);
300                    buf.push_str(" DO UPDATE SET ");
301                    let updates: Vec<Expr> = updates
302                        .into_iter()
303                        .map(|(c, v)| Expr::new_eq(Expr::column(c), v.clone()))
304                        .collect();
305                    buf.push_sql_sequence(&updates, ", ", dialect);
306                }
307                DoUpdateAllRows {
308                    conflict,
309                    alternate_values,
310                    ignore_columns,
311                } => {
312                    buf.push_str(" ON CONFLICT ");
313                    buf.push_sql(conflict, dialect);
314                    buf.push_str(" DO UPDATE SET ");
315                    let conflict_columns = conflict.as_columns();
316                    let columns: Vec<Expr> = self
317                        .columns
318                        .iter()
319                        .filter(|&c| !ignore_columns.contains(c))
320                        .filter(|&c| conflict_columns.map(|conflict| !conflict.contains(c)).unwrap_or(true))
321                        .map(|c| {
322                            let r = if let Some(v) = alternate_values.get(c) {
323                                v.clone()
324                            } else {
325                                Expr::excluded(c)
326                            };
327                            Expr::new_eq(Expr::column(c), r)
328                        })
329                        .collect();
330                    buf.push_sql_sequence(&columns, ", ", dialect);
331                }
332            }
333        }
334        if !self.returning.is_empty() {
335            buf.push_str(" RETURNING ");
336            buf.push_quoted_sequence(&self.returning, ", ");
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{Case, Expr};
    use pretty_assertions::assert_eq;

    /// Two literal rows plus a `RETURNING` clause, built from a struct literal.
    #[test]
    fn test_basic() {
        let rows: &[&[&str]] = &[&["1", "2"], &["3", "4"]];
        let stmt = Insert {
            schema: None,
            table: "foo".to_string(),
            columns: vec!["bar".to_string(), "baz".to_string()],
            values: Values::from(rows),
            on_conflict: OnConflict::Abort,
            returning: vec!["id".to_string()],
        };
        let rendered = stmt.to_sql(Dialect::Postgres);
        assert_eq!(
            rendered,
            r#"INSERT INTO "foo" ("bar", "baz") VALUES (1, 2), (3, 4) RETURNING "id""#
        );
    }

    /// One placeholder per column, upserting every non-key column.
    #[test]
    fn test_placeholders() {
        let stmt = Insert::new("foo")
            .columns(&["bar", "baz", "qux", "wibble", "wobble", "wubble"])
            .placeholder_for_each_column(Dialect::Postgres)
            .on_conflict(OnConflict::do_update_all_rows(&["bar"]));
        let want = r#"INSERT INTO "foo" ("bar", "baz", "qux", "wibble", "wobble", "wubble") VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT ("bar") DO UPDATE SET "baz" = excluded."baz", "qux" = excluded."qux", "wibble" = excluded."wibble", "wobble" = excluded."wobble", "wubble" = excluded."wubble""#;
        assert_eq!(stmt.to_sql(Dialect::Postgres), want);
    }

    /// An alternate value keeps `updated_at` unchanged when the incoming row matches.
    #[test]
    fn test_override() {
        let columns = &["id", "name", "email"];

        let unchanged = columns
            .iter()
            .copied()
            .map(|column| {
                Expr::not_distinct_from(
                    Expr::table_column("users", column),
                    Expr::excluded(column),
                )
            })
            .collect::<Vec<_>>();
        let updated_at = Expr::case(
            Case::new_when(
                Expr::new_and(unchanged),
                Expr::table_column("users", "updated_at"),
            )
            .els("excluded.updated_at"),
        );

        let row = Value::with(&["1", "Kurt", "test@example.com", "NOW()"]);
        let stmt = Insert::new("users")
            .columns(columns)
            .column("updated_at")
            .values(Values::new_value(row))
            .on_conflict(
                OnConflict::do_update_on_pkey("id").alternate_value("updated_at", updated_at),
            );
        let rendered = stmt.to_sql(Dialect::Postgres);
        let want = r#"
INSERT INTO "users" ("id", "name", "email", "updated_at") VALUES
(1, Kurt, test@example.com, NOW())
ON CONFLICT ("id") DO UPDATE SET
"name" = excluded."name",
"email" = excluded."email",
"updated_at" = CASE WHEN
("users"."id" IS NOT DISTINCT FROM excluded."id" AND
"users"."name" IS NOT DISTINCT FROM excluded."name" AND
"users"."email" IS NOT DISTINCT FROM excluded."email")
THEN "users"."updated_at"
ELSE excluded.updated_at END
"#
        .replace('\n', " ");
        assert_eq!(rendered, want.trim());
    }
}