spawn_db/escape.rs
1//! Type-safe SQL escaping for PostgreSQL.
2//!
3//! This module provides wrapper types that guarantee SQL values have been properly
4//! escaped at construction time. By using these types instead of raw strings,
5//! the type system ensures that escaping cannot be forgotten.
6//!
7//! # Example
8//!
9//! ```
10//! use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
11//!
12//! let schema = EscapedIdentifier::new("my_schema");
13//! let value = EscapedLiteral::new("user's input");
14//!
15//! let query = sql_query!(
16//! "SELECT * FROM {}.users WHERE name = {}",
17//! schema,
18//! value
19//! );
20//! ```
21
22use postgres_protocol::escape::{escape_identifier, escape_literal};
23use std::fmt;
24
/// A trait for types that are safe to interpolate into SQL queries.
///
/// Types implementing this trait can be used with the `sql_query!` macro,
/// which interpolates the string returned by [`SqlSafe::as_sql`] into the
/// query text without any further escaping. An implementor is therefore
/// responsible for returning text that is already escaped (or otherwise known
/// to be safe) for the position it will be placed in.
///
/// The built-in implementations are:
/// - `EscapedIdentifier` and `EscapedLiteral`, escaped at construction time;
/// - `InsecureRawSql`, which is passed through *unescaped*;
/// - `Option<S>` for any `S: SqlSafe`: `Some(v)` renders as `v` does and
///   `None` renders as the bare SQL keyword `NULL`;
/// - `&S` for any `S: SqlSafe`, which renders exactly as `S` does.
///
/// You may implement this trait for your own types if you have other
/// validated/escaped SQL fragments, but do so with caution.
pub trait SqlSafe {
    /// Returns the SQL text to interpolate, exactly as it should appear in the query.
    fn as_sql(&self) -> &str;
}
37
38impl<S: SqlSafe> SqlSafe for Option<S> {
39 fn as_sql(&self) -> &str {
40 if let Some(inner) = self {
41 return inner.as_sql();
42 }
43 "NULL"
44 }
45}
46
47impl<S: SqlSafe> SqlSafe for &S {
48 fn as_sql(&self) -> &str {
49 (*self).as_sql()
50 }
51}
52
53/// A PostgreSQL identifier (schema, table, column name) that has been safely escaped.
54///
55/// The value is escaped at construction time using PostgreSQL's `quote_ident` rules:
56/// - Wrapped in double quotes
57/// - Any embedded double quotes are doubled
58///
59/// # Example
60///
61/// ```
62/// use spawn_db::escape::EscapedIdentifier;
63///
64/// let schema = EscapedIdentifier::new("my_schema");
65/// assert_eq!(schema.as_str(), "\"my_schema\"");
66///
67/// let tricky = EscapedIdentifier::new("schema\"name");
68/// assert_eq!(tricky.as_str(), "\"schema\"\"name\"");
69/// ```
70#[derive(Debug, Clone, PartialEq, Eq, Hash)]
71pub struct EscapedIdentifier {
72 raw: String,
73 escaped: String,
74}
75
76impl EscapedIdentifier {
77 /// Creates a new escaped identifier from a raw string.
78 ///
79 /// The input is immediately escaped using PostgreSQL's identifier escaping rules.
80 pub fn new(raw: &str) -> Self {
81 Self {
82 raw: raw.to_string(),
83 escaped: escape_identifier(raw),
84 }
85 }
86
87 /// Returns the escaped identifier as a string slice.
88 ///
89 /// This value is safe to interpolate directly into SQL queries.
90 pub fn as_str(&self) -> &str {
91 &self.escaped
92 }
93
94 /// Returns the original unescaped value.
95 ///
96 /// Useful for error messages or logging where the raw value is needed.
97 pub fn raw_value(&self) -> &str {
98 &self.raw
99 }
100}
101
102impl fmt::Display for EscapedIdentifier {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 write!(f, "{}", self.escaped)
105 }
106}
107
108impl SqlSafe for EscapedIdentifier {
109 fn as_sql(&self) -> &str {
110 self.as_str()
111 }
112}
113
114/// A PostgreSQL string literal that has been safely escaped.
115///
116/// The value is escaped at construction time using PostgreSQL's `quote_literal` rules:
117/// - Wrapped in single quotes
118/// - Any embedded single quotes are doubled
119/// - If backslashes are present, prefixed with `E` and backslashes are doubled
120///
121/// # Example
122///
123/// ```
124/// use spawn_db::escape::EscapedLiteral;
125///
126/// let value = EscapedLiteral::new("hello");
127/// assert_eq!(value.as_str(), "'hello'");
128///
129/// let quoted = EscapedLiteral::new("it's");
130/// assert_eq!(quoted.as_str(), "'it''s'");
131/// ```
132#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133pub struct EscapedLiteral {
134 raw: String,
135 escaped: String,
136}
137
138impl EscapedLiteral {
139 /// Creates a new escaped literal from a raw string.
140 ///
141 /// The input is immediately escaped using PostgreSQL's literal escaping rules.
142 pub fn new(raw: &str) -> Self {
143 Self {
144 raw: raw.to_string(),
145 escaped: escape_literal(raw),
146 }
147 }
148
149 /// Returns the escaped literal as a string slice.
150 ///
151 /// This value is safe to interpolate directly into SQL queries.
152 pub fn as_str(&self) -> &str {
153 &self.escaped
154 }
155
156 /// Returns the original unescaped value.
157 ///
158 /// Useful for error messages or logging where the raw value is needed.
159 pub fn raw_value(&self) -> &str {
160 &self.raw
161 }
162}
163
164impl fmt::Display for EscapedLiteral {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 write!(f, "{}", self.escaped)
167 }
168}
169
170impl SqlSafe for EscapedLiteral {
171 fn as_sql(&self) -> &str {
172 self.as_str()
173 }
174}
175
176/// Raw SQL that has not been escaped.
177///
178/// This type is for cases where you genuinely need to include raw SQL that cannot
179/// be escaped, such as SQL keywords, operators, or pre-validated static strings.
180///
181/// # Warning
182///
183/// Use this type with extreme caution. It bypasses all escaping protections.
184/// Only use it for:
185/// - Static SQL fragments known at compile time
186/// - SQL that has been validated through other means
187///
188/// # Example
189///
190/// ```
191/// use spawn_db::{sql_query, escape::{EscapedIdentifier, InsecureRawSql}};
192///
193/// let schema = EscapedIdentifier::new("my_schema");
194/// let order = InsecureRawSql::new("ORDER BY created_at DESC");
195///
196/// let query = sql_query!(
197/// "SELECT * FROM {}.users {}",
198/// schema,
199/// order
200/// );
201/// ```
202#[derive(Debug, Clone, PartialEq, Eq, Hash)]
203pub struct InsecureRawSql(String);
204
205impl InsecureRawSql {
206 /// Creates a new raw SQL fragment.
207 ///
208 /// # Warning
209 ///
210 /// This does NOT escape the input. Only use this for SQL that you have
211 /// verified is safe, such as static strings or validated input.
212 pub fn new(raw: &str) -> Self {
213 Self(raw.to_string())
214 }
215
216 /// Returns the raw SQL as a string slice.
217 pub fn as_str(&self) -> &str {
218 &self.0
219 }
220}
221
222impl fmt::Display for InsecureRawSql {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 write!(f, "{}", self.0)
225 }
226}
227
228impl SqlSafe for InsecureRawSql {
229 fn as_sql(&self) -> &str {
230 self.as_str()
231 }
232}
233
/// A complete SQL query assembled from `SqlSafe` components.
///
/// Values of this type are meant to be produced by the `sql_query!` macro,
/// which only interpolates arguments implementing `SqlSafe`. The macro builds
/// the value through the hidden constructor `EscapedQuery::__new_from_macro`.
/// That constructor has to be `pub` so the macro can expand in other crates,
/// which means the type system does not actually prevent other code from
/// calling it directly. Treat any direct call outside the macro as a
/// review red flag.
///
/// # Example
///
/// ```
/// use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
///
/// let schema = EscapedIdentifier::new("public");
/// let name = EscapedLiteral::new("Alice");
///
/// let query = sql_query!(
///     "SELECT * FROM {}.users WHERE name = {}",
///     schema,
///     name
/// );
///
/// assert!(query.as_str().contains("\"public\""));
/// assert!(query.as_str().contains("'Alice'"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EscapedQuery(String);

impl EscapedQuery {
    /// Wraps an already-formatted query string. Used by `sql_query!`.
    ///
    /// This is `pub` only because `#[macro_export]` macros expand in the
    /// caller's crate and therefore need a public path to it. It is hidden
    /// from the docs and performs no escaping or validation of `sql`. Use the
    /// `sql_query!` macro instead of calling this directly.
    #[doc(hidden)]
    pub fn __new_from_macro(sql: String) -> Self {
        Self(sql)
    }

    /// Returns the query as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Displays the query text verbatim; width/fill flags are ignored.
impl fmt::Display for EscapedQuery {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
279
/// Compile-time guard used by the `sql_query!` macro. Returns `false` if the
/// format string contains a placeholder that could bypass the `SqlSafe` bound.
///
/// Two kinds of placeholder are rejected:
/// - named placeholders such as `{name}`: `format!` fills these by implicitly
///   capturing a variable of *any* `Display` type from the caller's scope, so
///   the value never passes through `SqlSafe::as_sql`;
/// - `Debug` placeholders such as `{:?}` or `{0:#?}`: `Debug` re-quotes and
///   backslash-escapes the already-escaped text, so the query would no longer
///   contain what `SqlSafe::as_sql` returned.
///
/// Positional placeholders (`{}`, `{0}`), `Display` format specs (`{:>10}`,
/// including `width$`/`precision$` arguments, which only accept `usize`) and
/// escaped braces (`{{`, `}}`) are accepted.
///
/// This is an implementation detail of the macro and not part of the public API.
#[doc(hidden)]
pub const fn __sql_query_format_is_safe(fmt: &str) -> bool {
    let bytes = fmt.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    while i < len {
        if bytes[i] != b'{' {
            i += 1;
            continue;
        }
        // `{{` is an escaped literal brace, not a placeholder.
        if i + 1 < len && bytes[i + 1] == b'{' {
            i += 2;
            continue;
        }
        // Inspect the argument position: an identifier there is a named
        // argument, which `sql_query!` can only satisfy via implicit capture.
        let mut j = i + 1;
        while j < len && bytes[j].is_ascii_whitespace() {
            j += 1;
        }
        if j < len {
            let first = bytes[j];
            if first == b'_' || first.is_ascii_alphabetic() || first >= 0x80 {
                return false;
            }
        }
        // Walk to the closing brace, remembering the last non-whitespace byte:
        // the format trait is the final character of the spec, and `?` selects `Debug`.
        let mut last = 0u8;
        while j < len && bytes[j] != b'}' {
            if !bytes[j].is_ascii_whitespace() {
                last = bytes[j];
            }
            j += 1;
        }
        if last == b'?' {
            return false;
        }
        i = j + 1;
    }
    true
}

/// Creates an `EscapedQuery` from a format string and SQL-safe arguments.
///
/// This macro works like `format!`, but every interpolated value must be an
/// explicit argument implementing the `SqlSafe` trait. This ensures that all
/// interpolated values have been escaped or deliberately marked as raw SQL.
///
/// # Accepted Types
///
/// - `EscapedIdentifier` - for schema, table, and column names
/// - `EscapedLiteral` - for string values
/// - `InsecureRawSql` - for raw SQL (use with caution)
/// - `Option<T>` where `T: SqlSafe` - `None` is rendered as `NULL`
/// - references to any `SqlSafe` type
///
/// # Placeholders
///
/// Use positional placeholders (`{}` or `{0}`) with explicit arguments. Two
/// kinds of placeholder are rejected at compile time because they would let
/// text reach the query without going through `SqlSafe::as_sql`:
///
/// - named placeholders such as `{name}`, which `format!` would fill by
///   implicitly capturing a variable of any `Display` type from the caller's scope;
/// - `Debug` placeholders such as `{:?}`, which re-quote and backslash-escape
///   the already-escaped text.
///
/// # Example
///
/// ```
/// use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
///
/// let schema = EscapedIdentifier::new("my_schema");
/// let table = EscapedIdentifier::new("users");
/// let name = EscapedLiteral::new("O'Brien");
///
/// let query = sql_query!(
///     "SELECT * FROM {}.{} WHERE name = {}",
///     schema,
///     table,
///     name
/// );
/// ```
///
/// # Compile-Time Safety
///
/// Passing a raw `String` or `&str` will result in a compile error:
///
/// ```compile_fail
/// use spawn_db::sql_query;
///
/// let unsafe_input = "Robert'; DROP TABLE users; --";
/// let query = sql_query!("SELECT * FROM users WHERE name = {}", unsafe_input);
/// // Error: the trait bound `&str: SqlSafe` is not satisfied
/// ```
///
/// So will capturing one by name inside the format string:
///
/// ```compile_fail
/// use spawn_db::sql_query;
///
/// let unsafe_input = "Robert'; DROP TABLE users; --";
/// let query = sql_query!("SELECT * FROM users WHERE name = {unsafe_input}");
/// ```
#[macro_export]
macro_rules! sql_query {
    ($fmt:literal $(, $arg:expr)* $(,)?) => {{
        // Evaluated at compile time: rejects named captures and Debug
        // placeholders, both of which would bypass the `SqlSafe` bound below.
        const _: () = ::core::assert!(
            $crate::escape::__sql_query_format_is_safe($fmt),
            "sql_query!: use positional placeholders with explicit SqlSafe arguments; \
             named captures and Debug formatting are not allowed"
        );
        $crate::escape::EscapedQuery::__new_from_macro(
            ::std::format!($fmt $(, $crate::escape::SqlSafe::as_sql(&$arg))*)
        )
    }};
}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escaped_identifier_basic() {
        let ident = EscapedIdentifier::new("my_schema");
        assert_eq!(ident.as_str(), "\"my_schema\"");
    }

    #[test]
    fn test_escaped_identifier_doubles_embedded_quotes() {
        let ident = EscapedIdentifier::new("schema\"name");
        assert_eq!(ident.as_str(), "\"schema\"\"name\"");
        assert_eq!(ident.raw_value(), "schema\"name");
        assert_eq!(ident.to_string(), ident.as_str());
    }

    #[test]
    fn test_escaped_literal_basic() {
        let lit = EscapedLiteral::new("hello");
        assert_eq!(lit.as_str(), "'hello'");
    }

    #[test]
    fn test_escaped_literal_doubles_embedded_quotes() {
        let lit = EscapedLiteral::new("it's");
        assert_eq!(lit.as_str(), "'it''s'");
        assert_eq!(lit.raw_value(), "it's");
    }

    #[test]
    fn test_escaped_literal_with_backslash_uses_escape_string_syntax() {
        let lit = EscapedLiteral::new("a\\b");
        // Only the suffix is pinned: the `E'...'` form may be preceded by whitespace.
        assert!(lit.as_str().ends_with("E'a\\\\b'"), "got {}", lit.as_str());
    }

    #[test]
    fn test_sql_query_macro() {
        let schema = EscapedIdentifier::new("public");
        let name = EscapedLiteral::new("Alice");

        let query = sql_query!("SELECT * FROM {}.users WHERE name = {}", schema, name);

        assert_eq!(
            query.as_str(),
            "SELECT * FROM \"public\".users WHERE name = 'Alice'"
        );
    }

    #[test]
    fn test_sql_query_with_insecure_raw() {
        let schema = EscapedIdentifier::new("public");
        let order = InsecureRawSql::new("ORDER BY id DESC");

        let query = sql_query!("SELECT * FROM {}.users {}", schema, order);

        assert_eq!(
            query.as_str(),
            "SELECT * FROM \"public\".users ORDER BY id DESC"
        );
    }

    #[test]
    fn test_sql_query_option_renders_null_or_inner() {
        let missing: Option<EscapedLiteral> = None;
        let present = Some(EscapedLiteral::new("x"));

        let query = sql_query!("VALUES ({}, {})", missing, present);

        assert_eq!(query.as_str(), "VALUES (NULL, 'x')");
    }

    #[test]
    fn test_sql_query_accepts_references() {
        let table = EscapedIdentifier::new("users");

        let query = sql_query!("SELECT * FROM {}", &table);

        assert_eq!(query.as_str(), "SELECT * FROM \"users\"");
    }

    #[test]
    fn test_sql_query_escapes_injection_attempt() {
        let malicious = EscapedLiteral::new("'; DROP TABLE users; --");

        let query = sql_query!("SELECT * FROM users WHERE name = {}", malicious);

        // The quote is doubled, making the malicious input a safe string literal
        assert_eq!(
            query.as_str(),
            "SELECT * FROM users WHERE name = '''; DROP TABLE users; --'"
        );
    }
}