sqlint/visitor/
postgres.rs

1use crate::{
2    ast::*,
3    visitor::{self, Visitor},
4};
5use std::fmt::{self, Write};
6
7/// A visitor to generate queries for the PostgreSQL database.
8///
9/// The returned parameter values implement the `ToSql` trait from postgres and
10/// can be used directly with the database.
11#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
12pub struct Postgres<'a> {
13    query: String,
14    parameters: Vec<Value<'a>>,
15}
16
17impl<'a> Visitor<'a> for Postgres<'a> {
18    const C_BACKTICK_OPEN: &'static str = "\"";
19    const C_BACKTICK_CLOSE: &'static str = "\"";
20    const C_WILDCARD: &'static str = "%";
21
22    fn build<Q>(query: Q) -> crate::Result<(String, Vec<Value<'a>>)>
23    where
24        Q: Into<Query<'a>>,
25    {
26        let mut postgres = Postgres { query: String::with_capacity(4096), parameters: Vec::with_capacity(128) };
27
28        Postgres::visit_query(&mut postgres, query.into())?;
29
30        Ok((postgres.query, postgres.parameters))
31    }
32
33    fn write<D: fmt::Display>(&mut self, s: D) -> visitor::Result {
34        write!(&mut self.query, "{s}")?;
35        Ok(())
36    }
37
38    fn add_parameter(&mut self, value: Value<'a>) {
39        self.parameters.push(value);
40    }
41
42    fn parameter_substitution(&mut self) -> visitor::Result {
43        self.write("$")?;
44        self.write(self.parameters.len())
45    }
46
47    fn visit_limit_and_offset(&mut self, limit: Option<Value<'a>>, offset: Option<Value<'a>>) -> visitor::Result {
48        match (limit, offset) {
49            (Some(limit), Some(offset)) => {
50                self.write(" LIMIT ")?;
51                self.visit_parameterized(limit)?;
52
53                self.write(" OFFSET ")?;
54                self.visit_parameterized(offset)
55            }
56            (None, Some(offset)) => {
57                self.write(" OFFSET ")?;
58                self.visit_parameterized(offset)
59            }
60            (Some(limit), None) => {
61                self.write(" LIMIT ")?;
62                self.visit_parameterized(limit)
63            }
64            (None, None) => Ok(()),
65        }
66    }
67
68    fn visit_raw_value(&mut self, value: Value<'a>) -> visitor::Result {
69        let res = match value {
70            Value::Int32(i) => i.map(|i| self.write(i)),
71            Value::Int64(i) => i.map(|i| self.write(i)),
72            Value::Text(t) => t.map(|t| self.write(format!("'{t}'"))),
73            Value::Enum(e) => e.map(|e| self.write(e)),
74            Value::Bytes(b) => b.map(|b| self.write(format!("E'{}'", hex::encode(b)))),
75            Value::Boolean(b) => b.map(|b| self.write(b)),
76            Value::Xml(cow) => cow.map(|cow| self.write(format!("'{cow}'"))),
77            Value::Char(c) => c.map(|c| self.write(format!("'{c}'"))),
78            Value::Float(d) => d.map(|f| match f {
79                f if f.is_nan() => self.write("'NaN'"),
80                f if f == f32::INFINITY => self.write("'Infinity'"),
81                f if f == f32::NEG_INFINITY => self.write("'-Infinity"),
82                v => self.write(format!("{v:?}")),
83            }),
84            Value::Double(d) => d.map(|f| match f {
85                f if f.is_nan() => self.write("'NaN'"),
86                f if f == f64::INFINITY => self.write("'Infinity'"),
87                f if f == f64::NEG_INFINITY => self.write("'-Infinity"),
88                v => self.write(format!("{v:?}")),
89            }),
90            Value::Array(ary) => ary.map(|ary| {
91                self.surround_with("'{", "}'", |ref mut s| {
92                    let len = ary.len();
93
94                    for (i, item) in ary.into_iter().enumerate() {
95                        s.write(item)?;
96
97                        if i < len - 1 {
98                            s.write(",")?;
99                        }
100                    }
101
102                    Ok(())
103                })
104            }),
105            #[cfg(feature = "json")]
106            Value::Json(j) => j.map(|j| self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))),
107            #[cfg(feature = "bigdecimal")]
108            Value::Numeric(r) => r.map(|r| self.write(r)),
109            #[cfg(feature = "uuid")]
110            Value::Uuid(uuid) => uuid.map(|uuid| self.write(format!("'{}'", uuid.hyphenated()))),
111            #[cfg(feature = "chrono")]
112            Value::DateTime(dt) => dt.map(|dt| self.write(format!("'{}'", dt.to_rfc3339(),))),
113            #[cfg(feature = "chrono")]
114            Value::Date(date) => date.map(|date| self.write(format!("'{date}'"))),
115            #[cfg(feature = "chrono")]
116            Value::Time(time) => time.map(|time| self.write(format!("'{time}'"))),
117        };
118
119        match res {
120            Some(res) => res,
121            None => self.write("null"),
122        }
123    }
124
125    fn visit_insert(&mut self, insert: Insert<'a>) -> visitor::Result {
126        self.write("INSERT ")?;
127
128        if let Some(table) = insert.table.clone() {
129            self.write("INTO ")?;
130            self.visit_table(table, true)?;
131        }
132
133        match insert.values {
134            Expression { kind: ExpressionKind::Row(row), .. } => {
135                if row.values.is_empty() {
136                    self.write(" DEFAULT VALUES")?;
137                } else {
138                    let columns = insert.columns.len();
139
140                    self.write(" (")?;
141                    for (i, c) in insert.columns.into_iter().enumerate() {
142                        self.visit_column(c.name.into_owned().into())?;
143
144                        if i < (columns - 1) {
145                            self.write(",")?;
146                        }
147                    }
148
149                    self.write(")")?;
150                    self.write(" VALUES ")?;
151                    self.visit_row(row)?;
152                }
153            }
154            Expression { kind: ExpressionKind::Values(values), .. } => {
155                let columns = insert.columns.len();
156
157                self.write(" (")?;
158                for (i, c) in insert.columns.into_iter().enumerate() {
159                    self.visit_column(c.name.into_owned().into())?;
160
161                    if i < (columns - 1) {
162                        self.write(",")?;
163                    }
164                }
165
166                self.write(")")?;
167                self.write(" VALUES ")?;
168                let values_len = values.len();
169
170                for (i, row) in values.into_iter().enumerate() {
171                    self.visit_row(row)?;
172
173                    if i < (values_len - 1) {
174                        self.write(", ")?;
175                    }
176                }
177            }
178            expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?,
179        }
180
181        match insert.on_conflict {
182            Some(OnConflict::DoNothing) => self.write(" ON CONFLICT DO NOTHING")?,
183            Some(OnConflict::Update(update, constraints)) => {
184                self.write(" ON CONFLICT")?;
185                self.columns_to_bracket_list(constraints)?;
186                self.write(" DO ")?;
187
188                self.visit_upsert(update)?;
189            }
190            None => (),
191        }
192
193        if let Some(returning) = insert.returning {
194            if !returning.is_empty() {
195                let values = returning.into_iter().map(|r| r.into()).collect();
196                self.write(" RETURNING ")?;
197                self.visit_columns(values)?;
198            }
199        };
200
201        if let Some(comment) = insert.comment {
202            self.write(" ")?;
203            self.visit_comment(comment)?;
204        }
205
206        Ok(())
207    }
208
209    fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> visitor::Result {
210        self.write("ARRAY_TO_STRING")?;
211        self.write("(")?;
212        self.write("ARRAY_AGG")?;
213        self.write("(")?;
214        self.visit_expression(value)?;
215        self.write(")")?;
216        self.write("','")?;
217        self.write(")")
218    }
219
220    fn visit_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result {
221        // LHS must be cast to json/xml-text if the right is a json/xml-text value and vice versa.
222        let right_cast = match left {
223            #[cfg(feature = "json")]
224            _ if left.is_json_value() => "::jsonb",
225            _ if left.is_xml_value() => "::text",
226            _ => "",
227        };
228
229        let left_cast = match right {
230            #[cfg(feature = "json")]
231            _ if right.is_json_value() => "::jsonb",
232            _ if right.is_xml_value() => "::text",
233            _ => "",
234        };
235
236        self.visit_expression(left)?;
237        self.write(left_cast)?;
238        self.write(" = ")?;
239        self.visit_expression(right)?;
240        self.write(right_cast)?;
241
242        Ok(())
243    }
244
245    fn visit_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result {
246        // LHS must be cast to json/xml-text if the right is a json/xml-text value and vice versa.
247        let right_cast = match left {
248            #[cfg(feature = "json")]
249            _ if left.is_json_value() => "::jsonb",
250            _ if left.is_xml_value() => "::text",
251            _ => "",
252        };
253
254        let left_cast = match right {
255            #[cfg(feature = "json")]
256            _ if right.is_json_value() => "::jsonb",
257            _ if right.is_xml_value() => "::text",
258            _ => "",
259        };
260
261        self.visit_expression(left)?;
262        self.write(left_cast)?;
263        self.write(" <> ")?;
264        self.visit_expression(right)?;
265        self.write(right_cast)?;
266
267        Ok(())
268    }
269
270    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
271    fn visit_json_extract(&mut self, json_extract: JsonExtract<'a>) -> visitor::Result {
272        match json_extract.path {
273            #[cfg(feature = "mysql")]
274            JsonPath::String(_) => panic!("JSON path string notation is not supported for Postgres"),
275            JsonPath::Array(json_path) => {
276                self.write("(")?;
277                self.visit_expression(*json_extract.column)?;
278
279                if json_extract.extract_as_string {
280                    self.write("#>>")?;
281                } else {
282                    self.write("#>")?;
283                }
284
285                // We use the `ARRAY[]::text[]` notation to better handle escaped character
286                // The text protocol used when sending prepared statement doesn't seem to work well with escaped characters
287                // when using the '{a, b, c}' string array notation.
288                self.surround_with("ARRAY[", "]::text[]", |s| {
289                    let len = json_path.len();
290                    for (index, path) in json_path.into_iter().enumerate() {
291                        s.visit_parameterized(Value::text(path))?;
292                        if index < len - 1 {
293                            s.write(", ")?;
294                        }
295                    }
296                    Ok(())
297                })?;
298
299                self.write(")")?;
300
301                if !json_extract.extract_as_string {
302                    self.write("::jsonb")?;
303                }
304            }
305        }
306
307        Ok(())
308    }
309
310    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
311    fn visit_json_unquote(&mut self, json_unquote: JsonUnquote<'a>) -> visitor::Result {
312        self.write("(")?;
313        self.visit_expression(*json_unquote.expr)?;
314        self.write("#>>ARRAY[]::text[]")?;
315        self.write(")")?;
316
317        Ok(())
318    }
319
320    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
321    fn visit_json_array_contains(&mut self, left: Expression<'a>, right: Expression<'a>, not: bool) -> visitor::Result {
322        if not {
323            self.write("( NOT ")?;
324        }
325
326        self.visit_expression(left)?;
327        self.write(" @> ")?;
328        self.visit_expression(right)?;
329
330        if not {
331            self.write(" )")?;
332        }
333
334        Ok(())
335    }
336
337    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
338    fn visit_json_extract_last_array_item(&mut self, extract: JsonExtractLastArrayElem<'a>) -> visitor::Result {
339        self.write("(")?;
340        self.visit_expression(*extract.expr)?;
341        self.write("->-1")?;
342        self.write(")")?;
343
344        Ok(())
345    }
346
347    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
348    fn visit_json_extract_first_array_item(&mut self, extract: JsonExtractFirstArrayElem<'a>) -> visitor::Result {
349        self.write("(")?;
350        self.visit_expression(*extract.expr)?;
351        self.write("->0")?;
352        self.write(")")?;
353
354        Ok(())
355    }
356
357    #[cfg(all(feature = "json", any(feature = "postgresql", feature = "mysql")))]
358    fn visit_json_type_equals(&mut self, left: Expression<'a>, json_type: JsonType<'a>, not: bool) -> visitor::Result {
359        self.write("JSONB_TYPEOF")?;
360        self.write("(")?;
361        self.visit_expression(left)?;
362        self.write(")")?;
363
364        if not {
365            self.write(" != ")?;
366        } else {
367            self.write(" = ")?;
368        }
369
370        match json_type {
371            JsonType::Array => self.visit_expression(Value::text("array").into()),
372            JsonType::Boolean => self.visit_expression(Value::text("boolean").into()),
373            JsonType::Number => self.visit_expression(Value::text("number").into()),
374            JsonType::Object => self.visit_expression(Value::text("object").into()),
375            JsonType::String => self.visit_expression(Value::text("string").into()),
376            JsonType::Null => self.visit_expression(Value::text("null").into()),
377            JsonType::ColumnRef(column) => {
378                self.write("JSONB_TYPEOF")?;
379                self.write("(")?;
380                self.visit_column(*column)?;
381                self.write("::jsonb)")
382            }
383        }
384    }
385
386    fn visit_text_search(&mut self, text_search: crate::prelude::TextSearch<'a>) -> visitor::Result {
387        let len = text_search.exprs.len();
388        self.surround_with("to_tsvector(concat_ws(' ', ", "))", |s| {
389            for (i, expr) in text_search.exprs.into_iter().enumerate() {
390                s.visit_expression(expr)?;
391
392                if i < (len - 1) {
393                    s.write(",")?;
394                }
395            }
396
397            Ok(())
398        })
399    }
400
401    fn visit_matches(&mut self, left: Expression<'a>, right: std::borrow::Cow<'a, str>, not: bool) -> visitor::Result {
402        if not {
403            self.write("(NOT ")?;
404        }
405
406        self.visit_expression(left)?;
407        self.write(" @@ ")?;
408        self.surround_with("to_tsquery(", ")", |s| s.visit_parameterized(Value::text(right)))?;
409
410        if not {
411            self.write(")")?;
412        }
413
414        Ok(())
415    }
416
417    fn visit_text_search_relevance(&mut self, text_search_relevance: TextSearchRelevance<'a>) -> visitor::Result {
418        let len = text_search_relevance.exprs.len();
419        let exprs = text_search_relevance.exprs;
420        let query = text_search_relevance.query;
421
422        self.write("ts_rank(")?;
423        self.surround_with("to_tsvector(concat_ws(' ', ", "))", |s| {
424            for (i, expr) in exprs.into_iter().enumerate() {
425                s.visit_expression(expr)?;
426
427                if i < (len - 1) {
428                    s.write(",")?;
429                }
430            }
431
432            Ok(())
433        })?;
434        self.write(", ")?;
435        self.surround_with("to_tsquery(", ")", |s| s.visit_parameterized(Value::text(query)))?;
436        self.write(")")?;
437
438        Ok(())
439    }
440
441    fn visit_like(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result {
442        let need_cast = matches!(&left.kind, ExpressionKind::Column(_));
443        self.visit_expression(left)?;
444
445        // NOTE: Pg is strongly typed, LIKE comparisons are only between strings.
446        // to avoid problems with types without implicit casting we explicitly cast to text
447        if need_cast {
448            self.write("::text")?;
449        }
450
451        self.write(" LIKE ")?;
452        self.visit_expression(right)?;
453
454        Ok(())
455    }
456
457    fn visit_not_like(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result {
458        let need_cast = matches!(&left.kind, ExpressionKind::Column(_));
459        self.visit_expression(left)?;
460
461        // NOTE: Pg is strongly typed, LIKE comparisons are only between strings.
462        // to avoid problems with types without implicit casting we explicitly cast to text
463        if need_cast {
464            self.write("::text")?;
465        }
466
467        self.write(" NOT LIKE ")?;
468        self.visit_expression(right)?;
469
470        Ok(())
471    }
472
473    fn visit_ordering(&mut self, ordering: Ordering<'a>) -> visitor::Result {
474        let len = ordering.0.len();
475
476        for (i, (value, ordering)) in ordering.0.into_iter().enumerate() {
477            let direction = ordering.map(|dir| match dir {
478                Order::Asc => " ASC",
479                Order::Desc => " DESC",
480                Order::AscNullsFirst => "ASC NULLS FIRST",
481                Order::AscNullsLast => "ASC NULLS LAST",
482                Order::DescNullsFirst => "DESC NULLS FIRST",
483                Order::DescNullsLast => "DESC NULLS LAST",
484            });
485
486            self.visit_expression(value)?;
487            self.write(direction.unwrap_or(""))?;
488
489            if i < (len - 1) {
490                self.write(", ")?;
491            }
492        }
493
494        Ok(())
495    }
496
497    fn visit_concat(&mut self, concat: Concat<'a>) -> visitor::Result {
498        let len = concat.exprs.len();
499
500        self.surround_with("(", ")", |s| {
501            for (i, expr) in concat.exprs.into_iter().enumerate() {
502                s.visit_expression(expr)?;
503
504                if i < (len - 1) {
505                    s.write(" || ")?;
506                }
507            }
508
509            Ok(())
510        })?;
511
512        Ok(())
513    }
514}
515
516#[cfg(test)]
517mod tests {
518    use crate::visitor::*;
519
520    fn expected_values<'a, T>(sql: &'static str, params: Vec<T>) -> (String, Vec<Value<'a>>)
521    where
522        T: Into<Value<'a>>,
523    {
524        (String::from(sql), params.into_iter().map(|p| p.into()).collect())
525    }
526
527    fn default_params<'a>(mut additional: Vec<Value<'a>>) -> Vec<Value<'a>> {
528        let mut result = Vec::new();
529
530        for param in additional.drain(0..) {
531            result.push(param)
532        }
533
534        result
535    }
536
537    #[test]
538    fn test_single_row_insert_default_values() {
539        let query = Insert::single_into("users");
540        let (sql, params) = Postgres::build(query).unwrap();
541
542        assert_eq!("INSERT INTO \"users\" DEFAULT VALUES", sql);
543        assert_eq!(default_params(vec![]), params);
544    }
545
546    #[test]
547    fn test_single_row_insert() {
548        let expected = expected_values("INSERT INTO \"users\" (\"foo\") VALUES ($1)", vec![10]);
549        let query = Insert::single_into("users").value("foo", 10);
550        let (sql, params) = Postgres::build(query).unwrap();
551
552        assert_eq!(expected.0, sql);
553        assert_eq!(expected.1, params);
554    }
555
556    #[test]
557    #[cfg(feature = "postgresql")]
558    fn test_returning_insert() {
559        let expected = expected_values("INSERT INTO \"users\" (\"foo\") VALUES ($1) RETURNING \"foo\"", vec![10]);
560        let query = Insert::single_into("users").value("foo", 10);
561        let (sql, params) = Postgres::build(Insert::from(query).returning(vec!["foo"])).unwrap();
562
563        assert_eq!(expected.0, sql);
564        assert_eq!(expected.1, params);
565    }
566
567    #[test]
568    #[cfg(feature = "postgresql")]
569    fn test_insert_on_conflict_update() {
570        let expected = expected_values(
571            "INSERT INTO \"users\" (\"foo\") VALUES ($1) ON CONFLICT (\"foo\") DO UPDATE SET \"foo\" = $2 WHERE \"users\".\"foo\" = $3 RETURNING \"foo\"",
572            vec![10, 3, 1],
573        );
574
575        let update = Update::table("users").set("foo", 3).so_that(("users", "foo").equals(1));
576
577        let query: Insert = Insert::single_into("users").value("foo", 10).into();
578
579        let query = query.on_conflict(OnConflict::Update(update, Vec::from(["foo".into()])));
580
581        let (sql, params) = Postgres::build(query.returning(vec!["foo"])).unwrap();
582
583        assert_eq!(expected.0, sql);
584        assert_eq!(expected.1, params);
585    }
586
587    #[test]
588    fn test_multi_row_insert() {
589        let expected = expected_values("INSERT INTO \"users\" (\"foo\") VALUES ($1), ($2)", vec![10, 11]);
590        let query = Insert::multi_into("users", vec!["foo"]).values(vec![10]).values(vec![11]);
591        let (sql, params) = Postgres::build(query).unwrap();
592
593        assert_eq!(expected.0, sql);
594        assert_eq!(expected.1, params);
595    }
596
597    #[test]
598    fn test_limit_and_offset_when_both_are_set() {
599        let expected = expected_values("SELECT \"users\".* FROM \"users\" LIMIT $1 OFFSET $2", vec![10_i64, 2_i64]);
600        let query = Select::from_table("users").limit(10).offset(2);
601        let (sql, params) = Postgres::build(query).unwrap();
602
603        assert_eq!(expected.0, sql);
604        assert_eq!(expected.1, params);
605    }
606
607    #[test]
608    fn test_limit_and_offset_when_only_offset_is_set() {
609        let expected = expected_values("SELECT \"users\".* FROM \"users\" OFFSET $1", vec![10_i64]);
610        let query = Select::from_table("users").offset(10);
611        let (sql, params) = Postgres::build(query).unwrap();
612
613        assert_eq!(expected.0, sql);
614        assert_eq!(expected.1, params);
615    }
616
617    #[test]
618    fn test_limit_and_offset_when_only_limit_is_set() {
619        let expected = expected_values("SELECT \"users\".* FROM \"users\" LIMIT $1", vec![10_i64]);
620        let query = Select::from_table("users").limit(10);
621        let (sql, params) = Postgres::build(query).unwrap();
622
623        assert_eq!(expected.0, sql);
624        assert_eq!(expected.1, params);
625    }
626
627    #[test]
628    fn test_distinct() {
629        let expected_sql = "SELECT DISTINCT \"bar\" FROM \"test\"";
630        let query = Select::from_table("test").column(Column::new("bar")).distinct();
631        let (sql, _) = Postgres::build(query).unwrap();
632
633        assert_eq!(expected_sql, sql);
634    }
635
636    #[test]
637    fn test_distinct_with_subquery() {
638        let expected_sql = "SELECT DISTINCT (SELECT $1 FROM \"test2\"), \"bar\" FROM \"test\"";
639        let query = Select::from_table("test")
640            .value(Select::from_table("test2").value(val!(1)))
641            .column(Column::new("bar"))
642            .distinct();
643
644        let (sql, _) = Postgres::build(query).unwrap();
645
646        assert_eq!(expected_sql, sql);
647    }
648
649    #[test]
650    fn test_from() {
651        let expected_sql = "SELECT \"foo\".*, \"bar\".\"a\" FROM \"foo\", (SELECT \"a\" FROM \"baz\") AS \"bar\"";
652        let query = Select::default()
653            .and_from("foo")
654            .and_from(Table::from(Select::from_table("baz").column("a")).alias("bar"))
655            .value(Table::from("foo").asterisk())
656            .column(("bar", "a"));
657
658        let (sql, _) = Postgres::build(query).unwrap();
659        assert_eq!(expected_sql, sql);
660    }
661
662    #[test]
663    fn test_comment_select() {
664        let expected_sql = "SELECT \"users\".* FROM \"users\" /* trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2' */";
665        let query = Select::from_table("users")
666            .comment("trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2'");
667
668        let (sql, _) = Postgres::build(query).unwrap();
669
670        assert_eq!(expected_sql, sql);
671    }
672
673    #[test]
674    fn test_comment_insert() {
675        let expected_sql = "INSERT INTO \"users\" DEFAULT VALUES /* trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2' */";
676        let query = Insert::single_into("users");
677        let insert =
678            Insert::from(query).comment("trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2'");
679
680        let (sql, _) = Postgres::build(insert).unwrap();
681
682        assert_eq!(expected_sql, sql);
683    }
684
685    #[test]
686    fn test_comment_update() {
687        let expected_sql = "UPDATE \"users\" SET \"foo\" = $1 /* trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2' */";
688        let query = Update::table("users")
689            .set("foo", 10)
690            .comment("trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2'");
691
692        let (sql, _) = Postgres::build(query).unwrap();
693
694        assert_eq!(expected_sql, sql);
695    }
696
697    #[test]
698    fn test_comment_delete() {
699        let expected_sql =
700            "DELETE FROM \"users\" /* trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2' */";
701        let query = Delete::from_table("users")
702            .comment("trace_id='5bd66ef5095369c7b0d1f8f4bd33716a', parent_id='c532cb4098ac3dd2'");
703
704        let (sql, _) = Postgres::build(query).unwrap();
705
706        assert_eq!(expected_sql, sql);
707    }
708
709    #[cfg(feature = "json")]
710    #[test]
711    fn equality_with_a_json_value() {
712        let expected = expected_values(
713            r#"SELECT "users".* FROM "users" WHERE "jsonField"::jsonb = $1"#,
714            vec![serde_json::json!({"a": "b"})],
715        );
716
717        let query = Select::from_table("users").so_that(Column::from("jsonField").equals(serde_json::json!({"a":"b"})));
718        let (sql, params) = Postgres::build(query).unwrap();
719
720        assert_eq!(expected.0, sql);
721        assert_eq!(expected.1, params);
722    }
723
724    #[cfg(feature = "json")]
725    #[test]
726    fn equality_with_a_lhs_json_value() {
727        // A bit artificial, but checks if the ::jsonb casting is done correctly on the right side as well.
728        let expected = expected_values(
729            r#"SELECT "users".* FROM "users" WHERE $1 = "jsonField"::jsonb"#,
730            vec![serde_json::json!({"a": "b"})],
731        );
732
733        let value_expr: Expression = Value::json(serde_json::json!({"a":"b"})).into();
734        let query = Select::from_table("users").so_that(value_expr.equals(Column::from("jsonField")));
735        let (sql, params) = Postgres::build(query).unwrap();
736
737        assert_eq!(expected.0, sql);
738        assert_eq!(expected.1, params);
739    }
740
741    #[cfg(feature = "json")]
742    #[test]
743    fn difference_with_a_json_value() {
744        let expected = expected_values(
745            r#"SELECT "users".* FROM "users" WHERE "jsonField"::jsonb <> $1"#,
746            vec![serde_json::json!({"a": "b"})],
747        );
748
749        let query =
750            Select::from_table("users").so_that(Column::from("jsonField").not_equals(serde_json::json!({"a":"b"})));
751        let (sql, params) = Postgres::build(query).unwrap();
752
753        assert_eq!(expected.0, sql);
754        assert_eq!(expected.1, params);
755    }
756
757    #[cfg(feature = "json")]
758    #[test]
759    fn difference_with_a_lhs_json_value() {
760        let expected = expected_values(
761            r#"SELECT "users".* FROM "users" WHERE $1 <> "jsonField"::jsonb"#,
762            vec![serde_json::json!({"a": "b"})],
763        );
764
765        let value_expr: Expression = Value::json(serde_json::json!({"a":"b"})).into();
766        let query = Select::from_table("users").so_that(value_expr.not_equals(Column::from("jsonField")));
767        let (sql, params) = Postgres::build(query).unwrap();
768
769        assert_eq!(expected.0, sql);
770        assert_eq!(expected.1, params);
771    }
772
773    #[test]
774    fn equality_with_a_xml_value() {
775        let expected = expected_values(
776            r#"SELECT "users".* FROM "users" WHERE "xmlField"::text = $1"#,
777            vec![Value::xml("<salad>wurst</salad>")],
778        );
779
780        let query =
781            Select::from_table("users").so_that(Column::from("xmlField").equals(Value::xml("<salad>wurst</salad>")));
782        let (sql, params) = Postgres::build(query).unwrap();
783
784        assert_eq!(expected.0, sql);
785        assert_eq!(expected.1, params);
786    }
787
788    #[test]
789    fn equality_with_a_lhs_xml_value() {
790        let expected = expected_values(
791            r#"SELECT "users".* FROM "users" WHERE $1 = "xmlField"::text"#,
792            vec![Value::xml("<salad>wurst</salad>")],
793        );
794
795        let value_expr: Expression = Value::xml("<salad>wurst</salad>").into();
796        let query = Select::from_table("users").so_that(value_expr.equals(Column::from("xmlField")));
797        let (sql, params) = Postgres::build(query).unwrap();
798
799        assert_eq!(expected.0, sql);
800        assert_eq!(expected.1, params);
801    }
802
803    #[test]
804    fn difference_with_a_xml_value() {
805        let expected = expected_values(
806            r#"SELECT "users".* FROM "users" WHERE "xmlField"::text <> $1"#,
807            vec![Value::xml("<salad>wurst</salad>")],
808        );
809
810        let query = Select::from_table("users")
811            .so_that(Column::from("xmlField").not_equals(Value::xml("<salad>wurst</salad>")));
812        let (sql, params) = Postgres::build(query).unwrap();
813
814        assert_eq!(expected.0, sql);
815        assert_eq!(expected.1, params);
816    }
817
818    #[test]
819    fn difference_with_a_lhs_xml_value() {
820        let expected = expected_values(
821            r#"SELECT "users".* FROM "users" WHERE $1 <> "xmlField"::text"#,
822            vec![Value::xml("<salad>wurst</salad>")],
823        );
824
825        let value_expr: Expression = Value::xml("<salad>wurst</salad>").into();
826        let query = Select::from_table("users").so_that(value_expr.not_equals(Column::from("xmlField")));
827        let (sql, params) = Postgres::build(query).unwrap();
828
829        assert_eq!(expected.0, sql);
830        assert_eq!(expected.1, params);
831    }
832
833    #[test]
834    fn test_raw_null() {
835        let (sql, params) = Postgres::build(Select::default().value(Value::Text(None).raw())).unwrap();
836        assert_eq!("SELECT null", sql);
837        assert!(params.is_empty());
838    }
839
840    #[test]
841    fn test_raw_int() {
842        let (sql, params) = Postgres::build(Select::default().value(1.raw())).unwrap();
843        assert_eq!("SELECT 1", sql);
844        assert!(params.is_empty());
845    }
846
847    #[test]
848    fn test_raw_real() {
849        let (sql, params) = Postgres::build(Select::default().value(1.3f64.raw())).unwrap();
850        assert_eq!("SELECT 1.3", sql);
851        assert!(params.is_empty());
852    }
853
854    #[test]
855    fn test_raw_text() {
856        let (sql, params) = Postgres::build(Select::default().value("foo".raw())).unwrap();
857        assert_eq!("SELECT 'foo'", sql);
858        assert!(params.is_empty());
859    }
860
861    #[test]
862    fn test_raw_bytes() {
863        let (sql, params) = Postgres::build(Select::default().value(Value::bytes(vec![1, 2, 3]).raw())).unwrap();
864        assert_eq!("SELECT E'010203'", sql);
865        assert!(params.is_empty());
866    }
867
868    #[test]
869    fn test_raw_boolean() {
870        let (sql, params) = Postgres::build(Select::default().value(true.raw())).unwrap();
871        assert_eq!("SELECT true", sql);
872        assert!(params.is_empty());
873
874        let (sql, params) = Postgres::build(Select::default().value(false.raw())).unwrap();
875        assert_eq!("SELECT false", sql);
876        assert!(params.is_empty());
877    }
878
879    #[test]
880    fn test_raw_char() {
881        let (sql, params) = Postgres::build(Select::default().value(Value::character('a').raw())).unwrap();
882        assert_eq!("SELECT 'a'", sql);
883        assert!(params.is_empty());
884    }
885
886    #[test]
887    #[cfg(feature = "json")]
888    fn test_raw_json() {
889        let (sql, params) =
890            Postgres::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())).unwrap();
891        assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql);
892        assert!(params.is_empty());
893    }
894
895    #[test]
896    #[cfg(feature = "uuid")]
897    fn test_raw_uuid() {
898        let uuid = uuid::Uuid::new_v4();
899        let (sql, params) = Postgres::build(Select::default().value(uuid.raw())).unwrap();
900
901        assert_eq!(format!("SELECT '{}'", uuid.hyphenated()), sql);
902
903        assert!(params.is_empty());
904    }
905
906    #[test]
907    #[cfg(feature = "chrono")]
908    fn test_raw_datetime() {
909        let dt = chrono::Utc::now();
910        let (sql, params) = Postgres::build(Select::default().value(dt.raw())).unwrap();
911
912        assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql);
913        assert!(params.is_empty());
914    }
915
916    #[test]
917    fn test_raw_comparator() {
918        let (sql, _) = Postgres::build(Select::from_table("foo").so_that("bar".compare_raw("ILIKE", "baz%"))).unwrap();
919
920        assert_eq!(r#"SELECT "foo".* FROM "foo" WHERE "bar" ILIKE $1"#, sql);
921    }
922
923    #[test]
924    fn test_like_cast_to_string() {
925        let expected = expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#, vec!["%foo%"]);
926
927        let query = Select::from_table("test").so_that(Column::from("jsonField").like("%foo%"));
928        let (sql, params) = Postgres::build(query).unwrap();
929
930        assert_eq!(expected.0, sql);
931        assert_eq!(expected.1, params);
932    }
933
934    #[test]
935    fn test_not_like_cast_to_string() {
936        let expected =
937            expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#, vec!["%foo%"]);
938
939        let query = Select::from_table("test").so_that(Column::from("jsonField").not_like("%foo%"));
940        let (sql, params) = Postgres::build(query).unwrap();
941
942        assert_eq!(expected.0, sql);
943        assert_eq!(expected.1, params);
944    }
945
946    #[test]
947    fn test_begins_with_cast_to_string() {
948        let expected = expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#, vec!["%foo"]);
949
950        let query = Select::from_table("test").so_that(Column::from("jsonField").like("%foo"));
951        let (sql, params) = Postgres::build(query).unwrap();
952
953        assert_eq!(expected.0, sql);
954        assert_eq!(expected.1, params);
955    }
956
957    #[test]
958    fn test_not_begins_with_cast_to_string() {
959        let expected =
960            expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#, vec!["%foo"]);
961
962        let query = Select::from_table("test").so_that(Column::from("jsonField").not_like("%foo"));
963        let (sql, params) = Postgres::build(query).unwrap();
964
965        assert_eq!(expected.0, sql);
966        assert_eq!(expected.1, params);
967    }
968
969    #[test]
970    fn test_ends_with_cast_to_string() {
971        let expected = expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text LIKE $1"#, vec!["foo%"]);
972
973        let query = Select::from_table("test").so_that(Column::from("jsonField").like("foo%"));
974        let (sql, params) = Postgres::build(query).unwrap();
975
976        assert_eq!(expected.0, sql);
977        assert_eq!(expected.1, params);
978    }
979
980    #[test]
981    fn test_not_ends_with_cast_to_string() {
982        let expected =
983            expected_values(r#"SELECT "test".* FROM "test" WHERE "jsonField"::text NOT LIKE $1"#, vec!["foo%"]);
984
985        let query = Select::from_table("test").so_that(Column::from("jsonField").not_like("foo%"));
986        let (sql, params) = Postgres::build(query).unwrap();
987
988        assert_eq!(expected.0, sql);
989        assert_eq!(expected.1, params);
990    }
991
992    #[test]
993    fn test_default_insert() {
994        let insert = Insert::single_into("foo").value("foo", "bar").value("baz", default_value());
995
996        let (sql, _) = Postgres::build(insert).unwrap();
997
998        assert_eq!("INSERT INTO \"foo\" (\"foo\",\"baz\") VALUES ($1,DEFAULT)", sql);
999    }
1000
1001    #[test]
1002    fn join_is_inserted_positionally() {
1003        let joined_table =
1004            Table::from("User").left_join("Post".alias("p").on(("p", "userId").equals(Column::from(("User", "id")))));
1005        let q = Select::from_table(joined_table).and_from("Toto");
1006        let (sql, _) = Postgres::build(q).unwrap();
1007
1008        assert_eq!("SELECT \"User\".*, \"Toto\".* FROM \"User\" LEFT JOIN \"Post\" AS \"p\" ON \"p\".\"userId\" = \"User\".\"id\", \"Toto\"", sql);
1009    }
1010}