sea_query/backend/postgres/mod.rs

1pub(crate) mod extension;
2pub(crate) mod foreign_key;
3pub(crate) mod index;
4pub(crate) mod query;
5pub(crate) mod table;
6pub(crate) mod types;
7
8use super::*;
9
/// Query builder for the PostgreSQL dialect.
///
/// The builder carries no state: use the unit value `PostgresQueryBuilder` directly,
/// or obtain one through [`Default`].
#[derive(Debug, Default)]
pub struct PostgresQueryBuilder;
13
14const QUOTE: Quote = Quote(b'"', b'"');
15
16impl GenericBuilder for PostgresQueryBuilder {}
17
18impl SchemaBuilder for PostgresQueryBuilder {}
19
20impl QuotedBuilder for PostgresQueryBuilder {
21    fn quote(&self) -> Quote {
22        QUOTE
23    }
24}
25
26// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-BACKSLASH-TABLE
27impl EscapeBuilder for PostgresQueryBuilder {
28    fn needs_escape(&self, s: &str) -> bool {
29        s.chars().any(|c| match c {
30            '\x08' | '\x0C' | '\n' | '\r' | '\t' | '\\' | '\'' | '\0' => true,
31            c if c.is_ascii_control() => true,
32            _ => false,
33        })
34    }
35
36    fn write_escaped(&self, buffer: &mut impl Write, string: &str) {
37        for c in string.chars() {
38            match c {
39                '\x08' => buffer.write_str(r#"\b"#),
40                '\x0C' => buffer.write_str(r#"\f"#),
41                '\n' => buffer.write_str(r"\n"),
42                '\r' => buffer.write_str(r"\r"),
43                '\t' => buffer.write_str(r"\t"),
44                '\\' => buffer.write_str(r#"\\"#),
45                '\'' => buffer.write_str(r#"\'"#),
46                '\0' => buffer.write_str(r#"\0"#),
47                c if c.is_ascii_control() => write!(buffer, "\\{:03o}", c as u32),
48                _ => buffer.write_char(c),
49            }
50            .unwrap();
51        }
52    }
53
54    fn unescape_string(&self, string: &str) -> String {
55        let mut chars = string.chars().peekable();
56        let mut result = String::with_capacity(string.len());
57
58        while let Some(c) = chars.next() {
59            if c != '\\' {
60                result.push(c);
61                continue;
62            }
63
64            let Some(next) = chars.next() else {
65                result.push('\\');
66                continue;
67            };
68
69            match next {
70                'b' => result.push('\x08'),
71                'f' => result.push('\x0C'),
72                'n' => result.push('\n'),
73                'r' => result.push('\r'),
74                't' => result.push('\t'),
75                '0' => result.push('\0'),
76                '\'' => result.push('\''),
77                '\\' => result.push('\\'),
78                'u' => {
79                    let mut hex = String::new();
80                    for _ in 0..4 {
81                        if let Some(h) = chars.next() {
82                            hex.push(h);
83                        }
84                    }
85                    if let Ok(code) = u32::from_str_radix(&hex, 16) {
86                        if let Some(ch) = std::char::from_u32(code) {
87                            result.push(ch);
88                        }
89                    }
90                }
91                'U' => {
92                    let mut hex = String::new();
93                    for _ in 0..8 {
94                        if let Some(h) = chars.next() {
95                            hex.push(h);
96                        }
97                    }
98                    if let Ok(code) = u32::from_str_radix(&hex, 16) {
99                        if let Some(ch) = std::char::from_u32(code) {
100                            result.push(ch);
101                        }
102                    }
103                }
104                c @ '0'..='7' => {
105                    let mut oct = String::new();
106                    oct.push(c);
107                    for _ in 0..2 {
108                        if let Some(next_o) = chars.peek() {
109                            if ('0'..='7').contains(next_o) {
110                                oct.push(chars.next().unwrap());
111                            }
112                        }
113                    }
114                    if let Ok(val) = u8::from_str_radix(&oct, 8) {
115                        result.push(val as char);
116                    }
117                }
118                other => {
119                    result.push('\\');
120                    result.push(other);
121                }
122            }
123        }
124
125        result
126    }
127}
128
129impl TableRefBuilder for PostgresQueryBuilder {}
130
#[cfg(test)]
mod tests {
    use crate::{EscapeBuilder, PostgresQueryBuilder};

    /// Escaping the ASCII control range `0x00..=0x1F` uses a short escape for the six
    /// characters that have one (0x00, 0x08, 0x09, 0x0A, 0x0C, 0x0D) and a three-digit
    /// octal escape for every other one.
    #[test]
    fn test_write_escaped() {
        let builder = PostgresQueryBuilder;
        let input: String = (0u8..=31).map(char::from).collect();
        let output = builder.escape_string(&input);

        // Same order as the code points: 0x08, 0x0C, 0x0A, 0x0D, 0x09, 0x00.
        for short in [r"\b", r"\f", r"\n", r"\r", r"\t", r"\0"] {
            assert!(output.contains(short));
        }

        let has_short_form =
            |c: char| matches!(c, '\x00' | '\x08' | '\x09' | '\x0A' | '\x0C' | '\x0D');
        for byte in (0u8..=31).filter(|&b| !has_short_form(char::from(b))) {
            assert!(output.contains(&format!("\\{byte:03o}")));
        }
    }

    /// Unescaping decodes short, octal and Unicode escapes; re-escaping the result does
    /// not turn printable characters back into escapes.
    #[test]
    fn test_unescape_string() {
        let builder = PostgresQueryBuilder;
        let decoded = "\x08\x0C\n\r\t\0'\\ABC你😀";

        assert_eq!(
            builder.unescape_string(r"\b\f\n\r\t\0\'\\\101\102\103\u4F60\U0001F600"),
            decoded
        );

        // ASCII letters come back as-is rather than as octal escapes.
        assert_eq!(builder.escape_string(decoded), r"\b\f\n\r\t\0\'\\ABC你😀");
    }
}