1use crate::query::Expr;
2use crate::util::SqlExtension;
3use crate::{Dialect, Select, ToSql};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, PartialEq, Eq)]
7pub enum OnConflict {
8 Ignore,
9 Abort,
10 Replace,
12 DoUpdate {
14 conflict: Conflict,
15 updates: Vec<(String, Expr)>,
16 },
17 DoUpdateAllRows {
19 conflict: Conflict,
20 alternate_values: HashMap<String, Expr>,
21 ignore_columns: Vec<String>,
22 },
23}
24
25impl OnConflict {
26 pub fn do_update_all_rows(columns: &[&str]) -> Self {
27 OnConflict::DoUpdateAllRows {
28 conflict: Conflict::Columns(columns.iter().map(|c| c.to_string()).collect()),
29 alternate_values: HashMap::new(),
30 ignore_columns: Vec::new(),
31 }
32 }
33
34 pub fn do_update_on_pkey(pkey: &str) -> Self {
35 OnConflict::DoUpdateAllRows {
36 conflict: Conflict::Columns(vec![pkey.to_string()]),
37 alternate_values: HashMap::new(),
38 ignore_columns: Vec::new(),
39 }
40 }
41
42 pub fn alternate_value<V: Into<Expr>>(mut self, column: &str, value: V) -> Self {
43 match &mut self {
44 OnConflict::DoUpdateAllRows {
45 alternate_values, ..
46 } => {
47 alternate_values.insert(column.to_string(), value.into());
48 }
49 _ => panic!("alternate_value is only valid for DoUpdate"),
50 }
51 self
52 }
53}
54
/// The conflict target written after `ON CONFLICT`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Conflict {
    /// A parenthesised, comma-separated list of quoted column names.
    Columns(Vec<String>),
    /// A named constraint, written as `ON CONSTRAINT` followed by the quoted
    /// name.
    ConstraintName(String),
    /// No target: nothing is written.
    NoTarget,
}
61
62impl Conflict {
63 pub fn columns(t: impl IntoIterator<Item = impl Into<String>>) -> Self {
64 Conflict::Columns(t.into_iter().map(|c| c.into()).collect())
65 }
66
67 pub fn as_columns(&self) -> Option<&Vec<String>> {
68 match self {
69 Conflict::Columns(c) => Some(c),
70 _ => None,
71 }
72 }
73}
74
75impl Default for OnConflict {
76 fn default() -> Self {
77 OnConflict::Abort
78 }
79}
80
81impl ToSql for Conflict {
82 fn write_sql(&self, buf: &mut String, _dialect: Dialect) {
83 match self {
84 Conflict::Columns(c) => {
85 buf.push('(');
86 buf.push_quoted_sequence(c, ", ");
87 buf.push(')');
88 }
89 Conflict::ConstraintName(name) => {
90 buf.push_str("ON CONSTRAINT ");
91 buf.push_quoted(name);
92 }
93 Conflict::NoTarget => {}
94 }
95 }
96}
97
98impl ToSql for Values {
99 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
100 match self {
101 Values::Values(values) => {
102 let mut first_value = true;
103 for value in values {
104 if !first_value {
105 buf.push_str(", ");
106 }
107 let mut first = true;
108 buf.push('(');
109 for v in &value.0 {
110 if !first {
111 buf.push_str(", ");
112 }
113 buf.push_str(v);
114 first = false;
115 }
116 buf.push(')');
117 first_value = false;
118 }
119 }
120 Values::Select(select) => {
121 buf.push_sql(select, dialect);
122 }
123 Values::DefaultValues => {
124 buf.push_str("DEFAULT VALUES");
125 }
126 }
127 }
128}
129
/// One row of an insert: an ordered list of SQL fragments, one per column.
///
/// `Default` yields an empty row, matching the argument-less `Value::new`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Value(Vec<String>);
132
133impl Value {
134 pub fn with(values: &[&str]) -> Self {
135 Self(values.into_iter().map(|v| v.to_string()).collect())
136 }
137
138 pub fn new() -> Self {
139 Self(Vec::new())
140 }
141
142 pub fn column(mut self, value: &str) -> Self {
143 self.0.push(value.to_string());
144 self
145 }
146
147 pub fn placeholders(mut self, count: usize, dialect: Dialect) -> Self {
148 use Dialect::*;
149 for i in 1..(count + 1) {
150 match dialect {
151 Postgres => self.0.push(format!("${}", i)),
152 Mysql | Sqlite => self.0.push("?".to_string()),
153 }
154 }
155 self
156 }
157}
158
159impl From<Vec<String>> for Value {
160 fn from(values: Vec<String>) -> Self {
161 Self(values)
162 }
163}
164
165#[derive(Debug, Clone, PartialEq, Eq)]
166pub enum Values {
167 Values(Vec<Value>),
168 Select(Select),
169 DefaultValues,
170}
171
172impl From<&[&[&'static str]]> for Values {
173 fn from(values: &[&[&'static str]]) -> Self {
174 Self::Values(values.into_iter().map(|v| Value::with(v)).collect())
175 }
176}
177
178impl From<&[&'static str]> for Values {
179 fn from(values: &[&'static str]) -> Self {
180 Self::Values(vec![Value::with(values)])
181 }
182}
183
184impl Values {
185 pub fn new_value(value: Value) -> Self {
186 Self::Values(vec![value])
187 }
188
189 pub fn select(select: Select) -> Self {
190 Self::Select(select)
191 }
192
193 pub fn default_values() -> Self {
194 Self::DefaultValues
195 }
196
197 pub fn value(mut self, value: Value) -> Self {
198 match &mut self {
199 Self::Values(values) => values.push(value),
200 _ => panic!("Cannot add value to non-values"),
201 }
202 self
203 }
204}
205
206#[derive(Debug, Clone, PartialEq, Eq)]
207pub struct Insert {
208 pub schema: Option<String>,
209 pub table: String,
210 pub columns: Vec<String>,
211 pub values: Values,
212 pub on_conflict: OnConflict,
213 pub returning: Vec<String>,
214}
215
216impl Insert {
217 pub fn new(table: &str) -> Self {
218 Self {
219 schema: None,
220 table: table.to_string(),
221 columns: Vec::new(),
222 values: Values::DefaultValues,
223 on_conflict: OnConflict::default(),
224 returning: Vec::new(),
225 }
226 }
227
228 pub fn schema(mut self, schema: &str) -> Self {
229 self.schema = Some(schema.to_string());
230 self
231 }
232
233 pub fn column(mut self, column: &str) -> Self {
234 self.columns.push(column.to_string());
235 self
236 }
237
238 pub fn values(mut self, value: Values) -> Self {
239 self.values = value;
240 self
241 }
242
243 pub fn columns(mut self, columns: &[&str]) -> Self {
244 self.columns = columns.iter().map(|c| c.to_string()).collect();
245 self
246 }
247
248 pub fn placeholder_for_each_column(mut self, dialect: Dialect) -> Self {
249 self.values = Values::new_value(Value::new().placeholders(self.columns.len(), dialect));
250 self
251 }
252
253 #[deprecated(note = "Use .values(Values::from(...)) instead")]
254 pub fn one_value(mut self, values: &[&str]) -> Self {
255 self.values = Values::Values(vec![Value::with(values)]);
256 self
257 }
258
259 pub fn on_conflict(mut self, on_conflict: OnConflict) -> Self {
260 self.on_conflict = on_conflict;
261 self
262 }
263
264 pub fn returning(mut self, returning: &[&str]) -> Self {
265 self.returning = returning.iter().map(|r| r.to_string()).collect();
266 self
267 }
268}
269
270impl ToSql for Insert {
271 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
272 use Dialect::*;
273 use OnConflict::*;
274 if dialect == Sqlite {
275 match self.on_conflict {
276 Ignore => buf.push_str("INSERT OR IGNORE INTO "),
277 Abort => buf.push_str("INSERT OR ABORT INTO "),
278 Replace => buf.push_str("INSERT OR REPLACE INTO "),
279 DoUpdateAllRows { .. } | DoUpdate { .. } => {
280 panic!("Sqlite does not support ON CONFLICT DO UPDATE")
281 }
282 }
283 } else {
284 buf.push_str("INSERT INTO ");
285 }
286 buf.push_table_name(&self.schema, &self.table);
287 buf.push_str(" (");
288 buf.push_quoted_sequence(&self.columns, ", ");
289 buf.push_str(") VALUES ");
290 self.values.write_sql(buf, dialect);
291
292 if dialect == Postgres {
293 match &self.on_conflict {
294 Ignore => buf.push_str(" ON CONFLICT DO NOTHING"),
295 Abort => {}
296 Replace => panic!("Postgres does not support ON CONFLICT REPLACE"),
297 DoUpdate { conflict, updates } => {
298 buf.push_str(" ON CONFLICT ");
299 buf.push_sql(conflict, dialect);
300 buf.push_str(" DO UPDATE SET ");
301 let updates: Vec<Expr> = updates
302 .into_iter()
303 .map(|(c, v)| Expr::new_eq(Expr::column(c), v.clone()))
304 .collect();
305 buf.push_sql_sequence(&updates, ", ", dialect);
306 }
307 DoUpdateAllRows {
308 conflict,
309 alternate_values,
310 ignore_columns,
311 } => {
312 buf.push_str(" ON CONFLICT ");
313 buf.push_sql(conflict, dialect);
314 buf.push_str(" DO UPDATE SET ");
315 let conflict_columns = conflict.as_columns();
316 let columns: Vec<Expr> = self
317 .columns
318 .iter()
319 .filter(|&c| !ignore_columns.contains(c))
320 .filter(|&c| conflict_columns.map(|conflict| !conflict.contains(c)).unwrap_or(true))
321 .map(|c| {
322 let r = if let Some(v) = alternate_values.get(c) {
323 v.clone()
324 } else {
325 Expr::excluded(c)
326 };
327 Expr::new_eq(Expr::column(c), r)
328 })
329 .collect();
330 buf.push_sql_sequence(&columns, ", ", dialect);
331 }
332 }
333 }
334 if !self.returning.is_empty() {
335 buf.push_str(" RETURNING ");
336 buf.push_quoted_sequence(&self.returning, ", ");
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{Case, Expr};
    use pretty_assertions::assert_eq;

    /// Two explicit rows plus a `RETURNING` clause, no conflict handling.
    #[test]
    fn test_basic() {
        let statement = Insert::new("foo")
            .columns(&["bar", "baz"])
            .values(Values::from(
                &[&["1", "2"] as &[&str], &["3", "4"]] as &[&[&str]]
            ))
            .on_conflict(OnConflict::Abort)
            .returning(&["id"]);
        let rendered = statement.to_sql(Dialect::Postgres);
        assert_eq!(
            rendered,
            r#"INSERT INTO "foo" ("bar", "baz") VALUES (1, 2), (3, 4) RETURNING "id""#
        );
    }

    /// One placeholder per column, with every non-conflict column updated
    /// from `excluded` on conflict.
    #[test]
    fn test_placeholders() {
        let expected = r#"INSERT INTO "foo" ("bar", "baz", "qux", "wibble", "wobble", "wubble") VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT ("bar") DO UPDATE SET "baz" = excluded."baz", "qux" = excluded."qux", "wibble" = excluded."wibble", "wobble" = excluded."wobble", "wubble" = excluded."wubble""#;
        let statement = Insert::new("foo")
            .columns(&["bar", "baz", "qux", "wibble", "wobble", "wubble"])
            .placeholder_for_each_column(Dialect::Postgres)
            .on_conflict(OnConflict::do_update_all_rows(&["bar"]));
        assert_eq!(statement.to_sql(Dialect::Postgres), expected);
    }

    /// An alternate value replaces the default `excluded` assignment for
    /// one column.
    #[test]
    fn test_override() {
        let columns = &["id", "name", "email"];

        // True when the incoming row equals the stored row in every column.
        let unchanged_checks: Vec<_> = columns
            .iter()
            .map(|&name| {
                Expr::not_distinct_from(Expr::table_column("users", name), Expr::excluded(name))
            })
            .collect();
        let keep_timestamp_if_unchanged = Expr::case(
            Case::new_when(
                Expr::new_and(unchanged_checks),
                Expr::table_column("users", "updated_at"),
            )
            .els("excluded.updated_at"),
        );

        let conflict_rule = OnConflict::do_update_on_pkey("id")
            .alternate_value("updated_at", keep_timestamp_if_unchanged);
        let row = Value::with(&["1", "Kurt", "test@example.com", "NOW()"]);
        let statement = Insert::new("users")
            .columns(columns)
            .column("updated_at")
            .values(Values::new_value(row))
            .on_conflict(conflict_rule);

        let sql = statement.to_sql(Dialect::Postgres);
        // Written across several lines for readability; the newlines are
        // turned into single spaces before comparing.
        let expected = r#"
INSERT INTO "users" ("id", "name", "email", "updated_at") VALUES
(1, Kurt, test@example.com, NOW())
ON CONFLICT ("id") DO UPDATE SET
"name" = excluded."name",
"email" = excluded."email",
"updated_at" = CASE WHEN
("users"."id" IS NOT DISTINCT FROM excluded."id" AND
"users"."name" IS NOT DISTINCT FROM excluded."name" AND
"users"."email" IS NOT DISTINCT FROM excluded."email")
THEN "users"."updated_at"
ELSE excluded.updated_at END
"#
        .replace('\n', " ");
        assert_eq!(sql, expected.trim());
    }
}