Skip to main content

spawn_db/
escape.rs

//! Type-safe SQL escaping for PostgreSQL.
//!
//! This module provides wrapper types that guarantee SQL values have been properly
//! escaped at construction time. By using these types instead of raw strings,
//! the type system ensures that escaping cannot be forgotten.
//!
//! # Example
//!
//! ```
//! use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
//!
//! let schema = EscapedIdentifier::new("my_schema");
//! let value = EscapedLiteral::new("user's input");
//!
//! let query = sql_query!(
//!     "SELECT * FROM {}.users WHERE name = {}",
//!     schema,
//!     value
//! );
//! ```
21
22use postgres_protocol::escape::{escape_identifier, escape_literal};
23use std::fmt;
24
/// Implemented by values whose text may be spliced verbatim into an SQL query.
///
/// The `sql_query!` macro only accepts arguments implementing this trait. The
/// concrete implementors provided by this module are `EscapedIdentifier`,
/// `EscapedLiteral`, and `InsecureRawSql`.
///
/// Other code may implement it for its own escaped or otherwise validated SQL
/// fragments. Each implementation is a promise that `as_sql` never yields
/// unescaped, untrusted text, so add one only after careful thought.
pub trait SqlSafe {
    /// Yields the text to insert into the query, exactly as it must appear in SQL.
    fn as_sql(&self) -> &str;
}
37
38impl<S: SqlSafe> SqlSafe for Option<S> {
39    fn as_sql(&self) -> &str {
40        if let Some(inner) = self {
41            return inner.as_sql();
42        }
43        "NULL"
44    }
45}
46
47impl<S: SqlSafe> SqlSafe for &S {
48    fn as_sql(&self) -> &str {
49        (*self).as_sql()
50    }
51}
52
/// A PostgreSQL identifier (schema, table, or column name) that has been safely escaped.
///
/// Escaping happens once, at construction time, via
/// `postgres_protocol::escape::escape_identifier`:
/// - the name is always wrapped in double quotes, and
/// - every embedded double quote is doubled.
///
/// Unlike SQL's `quote_ident` function, quoting here is unconditional. Quoted
/// identifiers are case-sensitive in PostgreSQL, so `EscapedIdentifier::new("Users")`
/// refers to `"Users"`, not to the folded name `users`.
///
/// # Example
///
/// ```
/// use spawn_db::escape::EscapedIdentifier;
///
/// let schema = EscapedIdentifier::new("my_schema");
/// assert_eq!(schema.as_str(), "\"my_schema\"");
///
/// let tricky = EscapedIdentifier::new("schema\"name");
/// assert_eq!(tricky.as_str(), "\"schema\"\"name\"");
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EscapedIdentifier {
    /// The caller-supplied name, before escaping (for diagnostics only).
    raw: String,
    /// `raw` after quoting; the only form that may be spliced into SQL.
    escaped: String,
}
75
76impl EscapedIdentifier {
77    /// Creates a new escaped identifier from a raw string.
78    ///
79    /// The input is immediately escaped using PostgreSQL's identifier escaping rules.
80    pub fn new(raw: &str) -> Self {
81        Self {
82            raw: raw.to_string(),
83            escaped: escape_identifier(raw),
84        }
85    }
86
87    /// Returns the escaped identifier as a string slice.
88    ///
89    /// This value is safe to interpolate directly into SQL queries.
90    pub fn as_str(&self) -> &str {
91        &self.escaped
92    }
93
94    /// Returns the original unescaped value.
95    ///
96    /// Useful for error messages or logging where the raw value is needed.
97    pub fn raw_value(&self) -> &str {
98        &self.raw
99    }
100}
101
102impl fmt::Display for EscapedIdentifier {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        write!(f, "{}", self.escaped)
105    }
106}
107
108impl SqlSafe for EscapedIdentifier {
109    fn as_sql(&self) -> &str {
110        self.as_str()
111    }
112}
113
/// A PostgreSQL string literal whose escaping was performed up front.
///
/// [`EscapedLiteral::new`] runs the input through
/// `postgres_protocol::escape::escape_literal`, following `quote_literal`-style rules:
/// - the value is enclosed in single quotes;
/// - each embedded single quote is doubled;
/// - if the input contains a backslash, every backslash is doubled and the literal
///   is written with escape-string syntax (`E'...'`).
///
/// # Example
///
/// ```
/// use spawn_db::escape::EscapedLiteral;
///
/// let value = EscapedLiteral::new("hello");
/// assert_eq!(value.as_str(), "'hello'");
///
/// let quoted = EscapedLiteral::new("it's");
/// assert_eq!(quoted.as_str(), "'it''s'");
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EscapedLiteral {
    /// The caller-supplied value, before escaping (for diagnostics only).
    raw: String,
    /// `raw` after quoting; the only form that may be spliced into SQL.
    escaped: String,
}
137
138impl EscapedLiteral {
139    /// Creates a new escaped literal from a raw string.
140    ///
141    /// The input is immediately escaped using PostgreSQL's literal escaping rules.
142    pub fn new(raw: &str) -> Self {
143        Self {
144            raw: raw.to_string(),
145            escaped: escape_literal(raw),
146        }
147    }
148
149    /// Returns the escaped literal as a string slice.
150    ///
151    /// This value is safe to interpolate directly into SQL queries.
152    pub fn as_str(&self) -> &str {
153        &self.escaped
154    }
155
156    /// Returns the original unescaped value.
157    ///
158    /// Useful for error messages or logging where the raw value is needed.
159    pub fn raw_value(&self) -> &str {
160        &self.raw
161    }
162}
163
164impl fmt::Display for EscapedLiteral {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        write!(f, "{}", self.escaped)
167    }
168}
169
170impl SqlSafe for EscapedLiteral {
171    fn as_sql(&self) -> &str {
172        self.as_str()
173    }
174}
175
/// An SQL fragment passed through **without any escaping**.
///
/// Use it only when the fragment cannot be expressed with the escaping types:
/// SQL keywords, operators, whole clauses, or static strings already known to be safe.
///
/// # Warning
///
/// Every escaping protection ends at this type. Acceptable sources are:
/// - SQL fragments fixed at compile time
/// - SQL whose safety has been established by some other validation
///
/// # Example
///
/// ```
/// use spawn_db::{sql_query, escape::{EscapedIdentifier, InsecureRawSql}};
///
/// let schema = EscapedIdentifier::new("my_schema");
/// let order = InsecureRawSql::new("ORDER BY created_at DESC");
///
/// let query = sql_query!(
///     "SELECT * FROM {}.users {}",
///     schema,
///     order
/// );
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InsecureRawSql(String);
204
205impl InsecureRawSql {
206    /// Creates a new raw SQL fragment.
207    ///
208    /// # Warning
209    ///
210    /// This does NOT escape the input. Only use this for SQL that you have
211    /// verified is safe, such as static strings or validated input.
212    pub fn new(raw: &str) -> Self {
213        Self(raw.to_string())
214    }
215
216    /// Returns the raw SQL as a string slice.
217    pub fn as_str(&self) -> &str {
218        &self.0
219    }
220}
221
222impl fmt::Display for InsecureRawSql {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        write!(f, "{}", self.0)
225    }
226}
227
228impl SqlSafe for InsecureRawSql {
229    fn as_sql(&self) -> &str {
230        self.as_str()
231    }
232}
233
/// A complete SQL query assembled from `SqlSafe` fragments.
///
/// Values of this type are meant to be produced by the `sql_query!` macro, which only
/// interpolates arguments implementing `SqlSafe`. The constructor the macro expands to
/// must be `pub`, because exported macros expand in the caller's crate. It is hidden
/// from the docs and is not intended for direct use; calling it directly skips every
/// escaping guarantee.
///
/// # Example
///
/// ```
/// use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
///
/// let schema = EscapedIdentifier::new("public");
/// let name = EscapedLiteral::new("Alice");
///
/// let query = sql_query!(
///     "SELECT * FROM {}.users WHERE name = {}",
///     schema,
///     name
/// );
///
/// assert!(query.as_str().contains("\"public\""));
/// assert!(query.as_str().contains("'Alice'"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EscapedQuery(String);
258
259impl EscapedQuery {
260    /// Creates a new EscapedQuery.
261    ///
262    /// This is intentionally private to the crate. Use the `sql_query!` macro instead.
263    #[doc(hidden)]
264    pub fn __new_from_macro(sql: String) -> Self {
265        Self(sql)
266    }
267
268    /// Returns the query as a string slice.
269    pub fn as_str(&self) -> &str {
270        &self.0
271    }
272}
273
274impl fmt::Display for EscapedQuery {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        write!(f, "{}", self.0)
277    }
278}
279
/// Creates an `EscapedQuery` from a format string and SQL-safe arguments.
///
/// This macro works like `format!`, but every argument is passed through
/// `SqlSafe::as_sql`, so only types implementing `SqlSafe` are accepted. This ensures
/// that all interpolated values have been properly escaped.
///
/// # Accepted Types
///
/// - `EscapedIdentifier` - for schema, table, and column names
/// - `EscapedLiteral` - for string values
/// - `InsecureRawSql` - for raw SQL (use with caution)
/// - `Option<T>` (rendered as `NULL` when `None`) and `&T`, for any accepted `T`
///
/// # Placeholders
///
/// Use plain `{}` or positional `{0}` placeholders with explicit arguments.
/// Inline captures such as `{name}` are rejected at compile time. Without that
/// restriction, `format!` would read `name` straight from the surrounding scope
/// without going through `SqlSafe`.
///
/// Avoid format specs such as `{:?}` or `{:.N}`. They apply to the already-escaped
/// text and can alter or truncate its quoting.
///
/// # Example
///
/// ```
/// use spawn_db::{sql_query, escape::{EscapedIdentifier, EscapedLiteral}};
///
/// let schema = EscapedIdentifier::new("my_schema");
/// let table = EscapedIdentifier::new("users");
/// let name = EscapedLiteral::new("O'Brien");
///
/// let query = sql_query!(
///     "SELECT * FROM {}.{} WHERE name = {}",
///     schema,
///     table,
///     name
/// );
/// ```
///
/// # Compile-Time Safety
///
/// Passing a raw `String` or `&str` will result in a compile error:
///
/// ```compile_fail
/// use spawn_db::sql_query;
///
/// let unsafe_input = "Robert'; DROP TABLE users; --";
/// let query = sql_query!("SELECT * FROM users WHERE name = {}", unsafe_input);
/// // Error: `&str` does not implement `SqlSafe`
/// ```
///
/// Capturing a variable inline in the format string is rejected too:
///
/// ```compile_fail
/// use spawn_db::sql_query;
///
/// let unsafe_input = "Robert'; DROP TABLE users; --";
/// let query = sql_query!("SELECT * FROM users WHERE name = '{unsafe_input}'");
/// // Error: there is no argument named `unsafe_input`
/// ```
#[macro_export]
macro_rules! sql_query {
    ($fmt:literal $(, $arg:expr)* $(,)?) => {{
        // Routing the literal through `concat!` makes the format string a macro
        // expansion instead of a direct literal. `format!` refuses implicit
        // `{ident}` captures in that case, so every interpolated value must be an
        // explicit argument, and therefore must pass through `SqlSafe`.
        $crate::escape::EscapedQuery::__new_from_macro(
            ::std::format!(
                ::std::concat!($fmt) $(, $crate::escape::SqlSafe::as_sql(&$arg))*
            )
        )
    }};
}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escaped_identifier_basic() {
        let ident = EscapedIdentifier::new("my_schema");
        assert_eq!(ident.as_str(), "\"my_schema\"");
    }

    #[test]
    fn test_escaped_identifier_doubles_embedded_quotes() {
        let ident = EscapedIdentifier::new("schema\"name");
        assert_eq!(ident.as_str(), "\"schema\"\"name\"");
        // The raw value is kept untouched for diagnostics.
        assert_eq!(ident.raw_value(), "schema\"name");
    }

    #[test]
    fn test_escaped_literal_basic() {
        let lit = EscapedLiteral::new("hello");
        assert_eq!(lit.as_str(), "'hello'");
    }

    #[test]
    fn test_escaped_literal_backslash_uses_escape_string_syntax() {
        let lit = EscapedLiteral::new(r"C:\temp");
        // Pin only the shape: any whitespace ahead of the `E` prefix is left to
        // `escape_literal`.
        let escaped = lit.as_str().trim_start();
        assert!(escaped.starts_with("E'"), "unexpected literal: {:?}", escaped);
        assert!(escaped.ends_with('\''), "unexpected literal: {:?}", escaped);
        assert!(escaped.contains(r"C:\\temp"), "unexpected literal: {:?}", escaped);
        assert_eq!(lit.raw_value(), r"C:\temp");
    }

    #[test]
    fn test_display_matches_as_str() {
        let ident = EscapedIdentifier::new("a\"b");
        let lit = EscapedLiteral::new("it's");
        let raw = InsecureRawSql::new("ORDER BY id");
        assert_eq!(ident.to_string(), ident.as_str());
        assert_eq!(lit.to_string(), lit.as_str());
        assert_eq!(raw.to_string(), raw.as_str());
    }

    #[test]
    fn test_sql_query_macro() {
        let schema = EscapedIdentifier::new("public");
        let name = EscapedLiteral::new("Alice");

        let query = sql_query!("SELECT * FROM {}.users WHERE name = {}", schema, name);

        assert_eq!(
            query.as_str(),
            "SELECT * FROM \"public\".users WHERE name = 'Alice'"
        );
    }

    #[test]
    fn test_sql_query_with_insecure_raw() {
        let schema = EscapedIdentifier::new("public");
        let order = InsecureRawSql::new("ORDER BY id DESC");

        let query = sql_query!("SELECT * FROM {}.users {}", schema, order);

        assert_eq!(
            query.as_str(),
            "SELECT * FROM \"public\".users ORDER BY id DESC"
        );
    }

    #[test]
    fn test_sql_query_renders_none_as_null() {
        let missing: Option<EscapedLiteral> = None;
        let present = Some(EscapedLiteral::new("x"));

        let query = sql_query!("INSERT INTO t VALUES ({}, {})", missing, present);

        assert_eq!(query.as_str(), "INSERT INTO t VALUES (NULL, 'x')");
    }

    #[test]
    fn test_sql_query_accepts_references() {
        let schema = EscapedIdentifier::new("public");
        let schema_ref = &schema;

        let query = sql_query!("SELECT * FROM {}.users", schema_ref);

        assert_eq!(query.as_str(), "SELECT * FROM \"public\".users");
    }

    #[test]
    fn test_sql_query_escapes_injection_attempt() {
        let malicious = EscapedLiteral::new("'; DROP TABLE users; --");

        let query = sql_query!("SELECT * FROM users WHERE name = {}", malicious);

        // The quote is doubled, making the malicious input a safe string literal
        assert_eq!(
            query.as_str(),
            "SELECT * FROM users WHERE name = '''; DROP TABLE users; --'"
        );
    }

    #[test]
    fn test_query_display_matches_as_str() {
        let query = sql_query!("SELECT {}", EscapedLiteral::new("a"));
        assert_eq!(query.to_string(), "SELECT 'a'");
        assert_eq!(query.as_str(), "SELECT 'a'");
    }
}