Skip to main content

spawn_db/sql_formatter/
postgres.rs

1//! PostgreSQL-specific SQL escaping for minijinja templates.
2//!
3//! This module provides safe SQL value formatting for PostgreSQL databases.
4//! It handles all minijinja value types and converts them to appropriate
5//! PostgreSQL literal syntax.
6//!
7//! # Escaping Rules
8//!
9//! - **Strings**: Escaped using PostgreSQL's `escape_literal` (handles quotes, special chars)
10//! - **Numbers**: Output directly (integers and floats are safe)
11//! - **Booleans**: Converted to `TRUE` / `FALSE`
12//! - **None**: Converted to `NULL`
13//! - **Undefined**: Empty string (consistent with minijinja defaults)
14//! - **Bytes**: Converted to PostgreSQL bytea hex format (`'\xDEADBEEF'::bytea`)
15//! - **Sequences**: Converted to PostgreSQL `ARRAY[...]` with recursively escaped elements
//! - **Maps**: Converted to minijinja's JSON-like string form and escaped; the text is not guaranteed to be valid JSON, so a `::jsonb` cast may fail
17//! - **Plain objects**: Stringified and escaped
18//! - **Invalid values**: Return an error
19//!
20//! # Security
21//!
22//! The only ways to bypass escaping are:
23//! - Using the `|safe` filter in templates (intentional)
24//! - Using `Value::from_safe_string()` in Rust code (requires explicit code)
25//!
26//! The `|escape` filter will error for custom SQL formats, preventing accidental misuse.
27
28use minijinja::value::ValueKind;
29use minijinja::{AutoEscape, Output, State, Value};
30use postgres_protocol::escape::escape_literal;
31
32/// The auto-escape format name for PostgreSQL.
33pub const FORMAT_NAME: &str = "sql-postgres";
34
35/// Auto-escape callback for PostgreSQL SQL templates.
36///
37/// Enables SQL escaping for all files.
38pub fn auto_escape_callback(_name: &str) -> AutoEscape {
39    AutoEscape::Custom(FORMAT_NAME)
40}
41
42/// Recursively formats a minijinja Value for safe PostgreSQL interpolation.
43///
44/// This handles all ValueKind variants appropriately for PostgreSQL syntax.
45fn format_value_for_postgres(value: &Value) -> Result<String, minijinja::Error> {
46    match value.kind() {
47        ValueKind::Undefined => Ok(String::new()),
48        ValueKind::None => Ok("NULL".to_string()),
49        ValueKind::Bool => {
50            let b: bool = value.clone().try_into().unwrap_or(false);
51            Ok(if b { "TRUE" } else { "FALSE" }.to_string())
52        }
53        ValueKind::Number => Ok(value.to_string()),
54        ValueKind::String => {
55            let s = value.to_string();
56            Ok(escape_literal(&s))
57        }
58        ValueKind::Bytes => {
59            // Convert to PostgreSQL bytea hex format: '\xDEADBEEF'::bytea
60            if let Some(bytes) = value.as_bytes() {
61                let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
62                Ok(format!("'\\x{}'::bytea", hex))
63            } else {
64                Err(minijinja::Error::new(
65                    minijinja::ErrorKind::InvalidOperation,
66                    "Could not extract bytes from value",
67                ))
68            }
69        }
70        ValueKind::Seq | ValueKind::Iterable => {
71            // Convert to PostgreSQL ARRAY[] syntax with recursively escaped elements
72            // e.g., [1, 'hello', true] becomes ARRAY[1, 'hello', TRUE]
73            let mut elements = Vec::new();
74            for item in value.try_iter().map_err(|e| {
75                minijinja::Error::new(
76                    minijinja::ErrorKind::InvalidOperation,
77                    format!("Could not iterate over sequence: {}", e),
78                )
79            })? {
80                elements.push(format_value_for_postgres(&item)?);
81            }
82            Ok(format!("ARRAY[{}]", elements.join(", ")))
83        }
84        ValueKind::Map => {
85            // Maps don't have a native SQL representation.
86            // Convert to JSON-like string representation and escape it.
87            // Users can cast to ::jsonb if needed: {{ my_map }}::jsonb
88            let s = value.to_string();
89            Ok(escape_literal(&s))
90        }
91        ValueKind::Plain => {
92            // For custom objects, stringify and escape as a string
93            let s = value.to_string();
94            Ok(escape_literal(&s))
95        }
96        ValueKind::Invalid => {
97            // Invalid values contain errors - propagate them
98            Err(minijinja::Error::new(
99                minijinja::ErrorKind::InvalidOperation,
100                format!("Invalid value encountered in SQL template: {}", value),
101            ))
102        }
103        // ValueKind is non-exhaustive, handle any future variants safely
104        _ => {
105            // For unknown types, stringify and escape as a string (safe default)
106            let s = value.to_string();
107            Ok(escape_literal(&s))
108        }
109    }
110}
111
112/// Custom formatter that escapes values for safe PostgreSQL interpolation.
113///
114/// This formatter is invoked when auto-escape is enabled (for .sql templates).
115/// It delegates to `format_value_for_postgres` for type-specific handling.
116///
117/// # Bypass Mechanisms
118///
119/// - Values marked as safe (via `|safe` filter) are output without escaping
120/// - Only applies when `state.auto_escape()` matches our custom format
121pub fn sql_escape_formatter(
122    out: &mut Output<'_>,
123    state: &State<'_, '_>,
124    value: &Value,
125) -> Result<(), minijinja::Error> {
126    // Check if we're in PostgreSQL SQL auto-escape mode
127    if state.auto_escape() == AutoEscape::Custom(FORMAT_NAME) {
128        // If the value is marked as safe (via |safe filter), skip escaping
129        if value.is_safe() {
130            return write!(out, "{}", value).map_err(|e| {
131                minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string())
132            });
133        }
134
135        let formatted = format_value_for_postgres(value)?;
136        write!(out, "{}", formatted)
137            .map_err(|e| minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string()))
138    } else {
139        // For non-SQL templates, use default formatting (no escaping)
140        write!(out, "{}", value)
141            .map_err(|e| minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string()))
142    }
143}
144
145#[cfg(test)]
146mod tests {
147    use super::*;
148    use minijinja::{context, Environment};
149
150    /// Helper to test SQL formatting of a value by rendering it in a .sql template
151    fn render_sql_value(value: Value) -> String {
152        let mut env = Environment::new();
153        env.set_auto_escape_callback(auto_escape_callback);
154        env.set_formatter(sql_escape_formatter);
155        env.add_template("test.sql", "{{ value }}").unwrap();
156        let tmpl = env.get_template("test.sql").unwrap();
157        tmpl.render(context!(value => value)).unwrap()
158    }
159
160    // ===================
161    // String escaping tests
162    // ===================
163
164    #[test]
165    fn test_sql_escape_string() {
166        let result = render_sql_value(Value::from("hello"));
167        assert_eq!(result, "'hello'");
168    }
169
170    #[test]
171    fn test_sql_escape_string_with_quotes() {
172        let result = render_sql_value(Value::from("it's a test"));
173        assert_eq!(result, "'it''s a test'");
174    }
175
176    #[test]
177    fn test_sql_escape_string_injection_attempt() {
178        let result = render_sql_value(Value::from("'; DROP TABLE users; --"));
179        assert_eq!(result, "'''; DROP TABLE users; --'");
180    }
181
182    // ===================
183    // Number tests
184    // ===================
185
186    #[test]
187    fn test_sql_escape_integer() {
188        let result = render_sql_value(Value::from(42));
189        assert_eq!(result, "42");
190    }
191
192    #[test]
193    fn test_sql_escape_negative_integer() {
194        let result = render_sql_value(Value::from(-123));
195        assert_eq!(result, "-123");
196    }
197
198    #[test]
199    fn test_sql_escape_float() {
200        let result = render_sql_value(Value::from(3.14));
201        assert_eq!(result, "3.14");
202    }
203
204    // ===================
205    // Boolean tests
206    // ===================
207
208    #[test]
209    fn test_sql_escape_bool_true() {
210        let result = render_sql_value(Value::from(true));
211        assert_eq!(result, "TRUE");
212    }
213
214    #[test]
215    fn test_sql_escape_bool_false() {
216        let result = render_sql_value(Value::from(false));
217        assert_eq!(result, "FALSE");
218    }
219
220    // ===================
221    // None/Undefined tests
222    // ===================
223
224    #[test]
225    fn test_sql_escape_none() {
226        let result = render_sql_value(Value::from(()));
227        assert_eq!(result, "NULL");
228    }
229
230    #[test]
231    fn test_sql_escape_undefined() {
232        let result = render_sql_value(Value::UNDEFINED);
233        assert_eq!(result, "");
234    }
235
236    // ===================
237    // Sequence/Array tests
238    // ===================
239
240    #[test]
241    fn test_sql_escape_seq_integers() {
242        let result = render_sql_value(Value::from(vec![1, 2, 3]));
243        assert_eq!(result, "ARRAY[1, 2, 3]");
244    }
245
246    #[test]
247    fn test_sql_escape_seq_strings() {
248        let result = render_sql_value(Value::from(vec!["hello", "world"]));
249        assert_eq!(result, "ARRAY['hello', 'world']");
250    }
251
252    #[test]
253    fn test_sql_escape_seq_strings_with_injection() {
254        let result = render_sql_value(Value::from(vec!["safe", "'; DROP TABLE users; --"]));
255        assert_eq!(result, "ARRAY['safe', '''; DROP TABLE users; --']");
256    }
257
258    #[test]
259    fn test_sql_escape_seq_mixed_types() {
260        let values: Vec<Value> = vec![Value::from(1), Value::from("hello"), Value::from(true)];
261        let result = render_sql_value(Value::from(values));
262        assert_eq!(result, "ARRAY[1, 'hello', TRUE]");
263    }
264
265    #[test]
266    fn test_sql_escape_seq_empty() {
267        let empty: Vec<i32> = vec![];
268        let result = render_sql_value(Value::from(empty));
269        assert_eq!(result, "ARRAY[]");
270    }
271
272    #[test]
273    fn test_sql_escape_nested_seq() {
274        let inner1 = Value::from(vec![1, 2]);
275        let inner2 = Value::from(vec![3, 4]);
276        let outer = Value::from(vec![inner1, inner2]);
277        let result = render_sql_value(outer);
278        assert_eq!(result, "ARRAY[ARRAY[1, 2], ARRAY[3, 4]]");
279    }
280
281    // ===================
282    // Map tests
283    // ===================
284
285    #[test]
286    fn test_sql_escape_map() {
287        use std::collections::BTreeMap;
288        let mut map = BTreeMap::new();
289        map.insert("name", "Alice");
290        map.insert("role", "admin");
291        let result = render_sql_value(Value::from(map));
292        // The map will be stringified and escaped as a SQL literal
293        assert!(result.starts_with("'"));
294        assert!(result.ends_with("'"));
295        assert!(result.contains("name"));
296        assert!(result.contains("Alice"));
297    }
298
299    // ===================
300    // Bytes tests
301    // ===================
302
303    #[test]
304    fn test_sql_escape_bytes() {
305        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
306        let result = render_sql_value(bytes);
307        assert_eq!(result, "'\\xdeadbeef'::bytea");
308    }
309
310    #[test]
311    fn test_sql_escape_bytes_empty() {
312        let bytes = Value::from_bytes(vec![]);
313        let result = render_sql_value(bytes);
314        assert_eq!(result, "'\\x'::bytea");
315    }
316
317    // ===================
318    // Auto-escape behavior tests
319    // ===================
320
321    #[test]
322    fn test_sql_escape_for_non_sql_templates() {
323        let mut env = Environment::new();
324        env.set_auto_escape_callback(auto_escape_callback);
325        env.set_formatter(sql_escape_formatter);
326        // Use .txt extension - should still trigger SQL escaping
327        env.add_template("test.txt", "{{ value }}").unwrap();
328        let tmpl = env.get_template("test.txt").unwrap();
329        let result = tmpl.render(context!(value => "hello")).unwrap();
330        // SQL escaping applies to all files
331        assert_eq!(result, "'hello'");
332    }
333
334    #[test]
335    fn test_sql_safe_filter_bypasses_escaping() {
336        let mut env = Environment::new();
337        env.set_auto_escape_callback(auto_escape_callback);
338        env.set_formatter(sql_escape_formatter);
339        // Using |safe filter should bypass escaping
340        env.add_template("test.sql", "{{ value|safe }}").unwrap();
341        let tmpl = env.get_template("test.sql").unwrap();
342        let result = tmpl.render(context!(value => "raw SQL here")).unwrap();
343        // Should be output as-is without quotes
344        assert_eq!(result, "raw SQL here");
345    }
346
347    #[test]
348    fn test_sql_escape_filter_fails_for_custom_format() {
349        // The |escape filter in minijinja does NOT work with custom formats
350        let mut env = Environment::new();
351        env.set_auto_escape_callback(auto_escape_callback);
352        env.set_formatter(sql_escape_formatter);
353
354        env.add_template("test.sql", "{{ value|escape }}").unwrap();
355        let tmpl = env.get_template("test.sql").unwrap();
356
357        let result = tmpl.render(context!(value => "test"));
358        assert!(result.is_err());
359        assert!(result
360            .unwrap_err()
361            .to_string()
362            .contains("does not know how to format to custom format"));
363    }
364
365    #[test]
366    fn test_sql_escape_from_safe_string_bypasses_escaping() {
367        // Value::from_safe_string() bypasses SQL escaping
368        let mut env = Environment::new();
369        env.set_auto_escape_callback(auto_escape_callback);
370        env.set_formatter(sql_escape_formatter);
371
372        env.add_template("test.sql", "{{ value }}").unwrap();
373        let tmpl = env.get_template("test.sql").unwrap();
374
375        let safe_value = Value::from_safe_string("1 OR 1=1".to_string());
376        let result = tmpl.render(context!(value => safe_value)).unwrap();
377        assert_eq!(result, "1 OR 1=1");
378
379        // Compare with normal string which gets escaped
380        let normal_value = Value::from("1 OR 1=1");
381        let result = tmpl.render(context!(value => normal_value)).unwrap();
382        assert_eq!(result, "'1 OR 1=1'");
383    }
384
385    // ===================
386    // Loop/conditional tests (verify formatter only applies to output)
387    // ===================
388
389    #[test]
390    fn test_sql_escape_only_on_output_not_in_loops() {
391        let mut env = Environment::new();
392        env.set_auto_escape_callback(auto_escape_callback);
393        env.set_formatter(sql_escape_formatter);
394
395        let template =
396            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
397        env.add_template("test.sql", template).unwrap();
398        let tmpl = env.get_template("test.sql").unwrap();
399
400        let items = vec!["alice", "bob", "charlie"];
401        let result = tmpl.render(context!(items => items)).unwrap();
402        assert_eq!(result, "'alice', 'bob', 'charlie'");
403    }
404
405    #[test]
406    fn test_sql_escape_only_on_output_not_in_conditionals() {
407        let mut env = Environment::new();
408        env.set_auto_escape_callback(auto_escape_callback);
409        env.set_formatter(sql_escape_formatter);
410
411        let template = r#"{% if enabled %}{{ value }}{% else %}NULL{% endif %}"#;
412        env.add_template("test.sql", template).unwrap();
413        let tmpl = env.get_template("test.sql").unwrap();
414
415        let result = tmpl
416            .render(context!(enabled => true, value => "test"))
417            .unwrap();
418        assert_eq!(result, "'test'");
419
420        let result = tmpl
421            .render(context!(enabled => false, value => "test"))
422            .unwrap();
423        assert_eq!(result, "NULL");
424    }
425
426    #[test]
427    fn test_sql_escape_loop_over_map_keys() {
428        let mut env = Environment::new();
429        env.set_auto_escape_callback(auto_escape_callback);
430        env.set_formatter(sql_escape_formatter);
431
432        let template = r#"{% for key, val in data|items %}{{ key }} = {{ val }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
433        env.add_template("test.sql", template).unwrap();
434        let tmpl = env.get_template("test.sql").unwrap();
435
436        use std::collections::BTreeMap;
437        let mut data = BTreeMap::new();
438        data.insert("name", "Alice");
439        data.insert("role", "admin");
440
441        let result = tmpl.render(context!(data => data)).unwrap();
442        assert_eq!(result, "'name' = 'Alice', 'role' = 'admin'");
443    }
444
445    #[test]
446    fn test_sql_escape_nested_loop() {
447        let mut env = Environment::new();
448        env.set_auto_escape_callback(auto_escape_callback);
449        env.set_formatter(sql_escape_formatter);
450
451        let template = r#"{% for row in rows %}({% for col in row %}{{ col }}{% if not loop.last %}, {% endif %}{% endfor %}){% if not loop.last %}, {% endif %}{% endfor %}"#;
452        env.add_template("test.sql", template).unwrap();
453        let tmpl = env.get_template("test.sql").unwrap();
454
455        let rows: Vec<Vec<&str>> = vec![vec!["a", "b"], vec!["c", "d"]];
456        let result = tmpl.render(context!(rows => rows)).unwrap();
457        assert_eq!(result, "('a', 'b'), ('c', 'd')");
458    }
459
460    #[test]
461    fn test_sql_escape_length_filter_works() {
462        let mut env = Environment::new();
463        env.set_auto_escape_callback(auto_escape_callback);
464        env.set_formatter(sql_escape_formatter);
465
466        let template = r#"{% if items|length > 0 %}{{ items|length }}{% else %}0{% endif %}"#;
467        env.add_template("test.sql", template).unwrap();
468        let tmpl = env.get_template("test.sql").unwrap();
469
470        let items = vec![1, 2, 3];
471        let result = tmpl.render(context!(items => items)).unwrap();
472        assert_eq!(result, "3");
473    }
474
475    // ===================
476    // Auto-escape callback tests
477    // ===================
478
479    #[test]
480    fn test_auto_escape_callback_all_files() {
481        // SQL escaping is enabled for all file types
482        assert_eq!(
483            auto_escape_callback("migration.sql"),
484            AutoEscape::Custom(FORMAT_NAME)
485        );
486        assert_eq!(
487            auto_escape_callback("path/to/file.sql"),
488            AutoEscape::Custom(FORMAT_NAME)
489        );
490        assert_eq!(
491            auto_escape_callback("file.txt"),
492            AutoEscape::Custom(FORMAT_NAME)
493        );
494        assert_eq!(
495            auto_escape_callback("file.html"),
496            AutoEscape::Custom(FORMAT_NAME)
497        );
498    }
499}