1use minijinja::value::ValueKind;
29use minijinja::{AutoEscape, Output, State, Value};
30use postgres_protocol::escape::escape_literal;
31
32pub const FORMAT_NAME: &str = "sql-postgres";
34
35pub fn auto_escape_callback(_name: &str) -> AutoEscape {
39 AutoEscape::Custom(FORMAT_NAME)
40}
41
42fn format_value_for_postgres(value: &Value) -> Result<String, minijinja::Error> {
46 match value.kind() {
47 ValueKind::Undefined => Ok(String::new()),
48 ValueKind::None => Ok("NULL".to_string()),
49 ValueKind::Bool => {
50 let b: bool = value.clone().try_into().unwrap_or(false);
51 Ok(if b { "TRUE" } else { "FALSE" }.to_string())
52 }
53 ValueKind::Number => Ok(value.to_string()),
54 ValueKind::String => {
55 let s = value.to_string();
56 Ok(escape_literal(&s))
57 }
58 ValueKind::Bytes => {
59 if let Some(bytes) = value.as_bytes() {
61 let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
62 Ok(format!("'\\x{}'::bytea", hex))
63 } else {
64 Err(minijinja::Error::new(
65 minijinja::ErrorKind::InvalidOperation,
66 "Could not extract bytes from value",
67 ))
68 }
69 }
70 ValueKind::Seq | ValueKind::Iterable => {
71 let mut elements = Vec::new();
74 for item in value.try_iter().map_err(|e| {
75 minijinja::Error::new(
76 minijinja::ErrorKind::InvalidOperation,
77 format!("Could not iterate over sequence: {}", e),
78 )
79 })? {
80 elements.push(format_value_for_postgres(&item)?);
81 }
82 Ok(format!("ARRAY[{}]", elements.join(", ")))
83 }
84 ValueKind::Map => {
85 let s = value.to_string();
89 Ok(escape_literal(&s))
90 }
91 ValueKind::Plain => {
92 let s = value.to_string();
94 Ok(escape_literal(&s))
95 }
96 ValueKind::Invalid => {
97 Err(minijinja::Error::new(
99 minijinja::ErrorKind::InvalidOperation,
100 format!("Invalid value encountered in SQL template: {}", value),
101 ))
102 }
103 _ => {
105 let s = value.to_string();
107 Ok(escape_literal(&s))
108 }
109 }
110}
111
112pub fn sql_escape_formatter(
122 out: &mut Output<'_>,
123 state: &State<'_, '_>,
124 value: &Value,
125) -> Result<(), minijinja::Error> {
126 if state.auto_escape() == AutoEscape::Custom(FORMAT_NAME) {
128 if value.is_safe() {
130 return write!(out, "{}", value).map_err(|e| {
131 minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string())
132 });
133 }
134
135 let formatted = format_value_for_postgres(value)?;
136 write!(out, "{}", formatted)
137 .map_err(|e| minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string()))
138 } else {
139 write!(out, "{}", value)
141 .map_err(|e| minijinja::Error::new(minijinja::ErrorKind::WriteFailure, e.to_string()))
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use minijinja::{context, Environment};
    use std::collections::BTreeMap;

    /// Builds an environment with SQL auto-escaping for every template and the
    /// SQL-aware output formatter installed.
    fn sql_env() -> Environment<'static> {
        let mut env = Environment::new();
        env.set_auto_escape_callback(auto_escape_callback);
        env.set_formatter(sql_escape_formatter);
        env
    }

    /// Like [`sql_env`], with a single template registered under `name`.
    fn sql_env_with(name: &'static str, source: &'static str) -> Environment<'static> {
        let mut env = sql_env();
        env.add_template(name, source).unwrap();
        env
    }

    /// Renders `{{ value }}` through a `.sql` template and returns the output.
    fn render_sql_value(value: Value) -> String {
        let env = sql_env_with("test.sql", "{{ value }}");
        let tmpl = env.get_template("test.sql").unwrap();
        tmpl.render(context!(value => value)).unwrap()
    }

    // ---- strings ----

    #[test]
    fn test_sql_escape_string() {
        assert_eq!(render_sql_value(Value::from("hello")), "'hello'");
    }

    #[test]
    fn test_sql_escape_string_with_quotes() {
        assert_eq!(render_sql_value(Value::from("it's a test")), "'it''s a test'");
    }

    #[test]
    fn test_sql_escape_string_injection_attempt() {
        assert_eq!(
            render_sql_value(Value::from("'; DROP TABLE users; --")),
            "'''; DROP TABLE users; --'"
        );
    }

    // ---- numbers ----

    #[test]
    fn test_sql_escape_integer() {
        assert_eq!(render_sql_value(Value::from(42)), "42");
    }

    #[test]
    fn test_sql_escape_negative_integer() {
        assert_eq!(render_sql_value(Value::from(-123)), "-123");
    }

    #[test]
    fn test_sql_escape_float() {
        // Not 3.14: clippy's deny-by-default `approx_constant` lint rejects it as an
        // approximation of PI and fails `cargo clippy --all-targets`.
        assert_eq!(render_sql_value(Value::from(2.5)), "2.5");
    }

    // ---- booleans ----

    #[test]
    fn test_sql_escape_bool_true() {
        assert_eq!(render_sql_value(Value::from(true)), "TRUE");
    }

    #[test]
    fn test_sql_escape_bool_false() {
        assert_eq!(render_sql_value(Value::from(false)), "FALSE");
    }

    // ---- none / undefined ----

    #[test]
    fn test_sql_escape_none() {
        assert_eq!(render_sql_value(Value::from(())), "NULL");
    }

    #[test]
    fn test_sql_escape_undefined() {
        assert_eq!(render_sql_value(Value::UNDEFINED), "");
    }

    // ---- sequences ----

    #[test]
    fn test_sql_escape_seq_integers() {
        assert_eq!(render_sql_value(Value::from(vec![1, 2, 3])), "ARRAY[1, 2, 3]");
    }

    #[test]
    fn test_sql_escape_seq_strings() {
        assert_eq!(
            render_sql_value(Value::from(vec!["hello", "world"])),
            "ARRAY['hello', 'world']"
        );
    }

    #[test]
    fn test_sql_escape_seq_strings_with_injection() {
        assert_eq!(
            render_sql_value(Value::from(vec!["safe", "'; DROP TABLE users; --"])),
            "ARRAY['safe', '''; DROP TABLE users; --']"
        );
    }

    #[test]
    fn test_sql_escape_seq_mixed_types() {
        let values: Vec<Value> = vec![Value::from(1), Value::from("hello"), Value::from(true)];
        assert_eq!(render_sql_value(Value::from(values)), "ARRAY[1, 'hello', TRUE]");
    }

    #[test]
    fn test_sql_escape_seq_empty() {
        let empty: Vec<i32> = vec![];
        assert_eq!(render_sql_value(Value::from(empty)), "ARRAY[]");
    }

    #[test]
    fn test_sql_escape_nested_seq() {
        let outer = Value::from(vec![Value::from(vec![1, 2]), Value::from(vec![3, 4])]);
        assert_eq!(render_sql_value(outer), "ARRAY[ARRAY[1, 2], ARRAY[3, 4]]");
    }

    // ---- maps ----

    #[test]
    fn test_sql_escape_map() {
        let mut map = BTreeMap::new();
        map.insert("name", "Alice");
        map.insert("role", "admin");
        let result = render_sql_value(Value::from(map));
        assert!(result.starts_with('\''));
        assert!(result.ends_with('\''));
        assert!(result.contains("name"));
        assert!(result.contains("Alice"));
    }

    // ---- bytes ----

    #[test]
    fn test_sql_escape_bytes() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(render_sql_value(bytes), "'\\xdeadbeef'::bytea");
    }

    #[test]
    fn test_sql_escape_bytes_empty() {
        let bytes = Value::from_bytes(vec![]);
        assert_eq!(render_sql_value(bytes), "'\\x'::bytea");
    }

    // ---- environment / filter interaction ----

    #[test]
    fn test_sql_escape_for_non_sql_templates() {
        let env = sql_env_with("test.txt", "{{ value }}");
        let tmpl = env.get_template("test.txt").unwrap();
        let result = tmpl.render(context!(value => "hello")).unwrap();
        assert_eq!(result, "'hello'");
    }

    #[test]
    fn test_sql_safe_filter_bypasses_escaping() {
        let env = sql_env_with("test.sql", "{{ value|safe }}");
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!(value => "raw SQL here")).unwrap();
        assert_eq!(result, "raw SQL here");
    }

    #[test]
    fn test_sql_escape_filter_fails_for_custom_format() {
        let env = sql_env_with("test.sql", "{{ value|escape }}");
        let tmpl = env.get_template("test.sql").unwrap();

        let err = tmpl
            .render(context!(value => "test"))
            .expect_err("the built-in escape filter cannot handle the custom SQL format");
        assert!(err
            .to_string()
            .contains("does not know how to format to custom format"));
    }

    #[test]
    fn test_sql_escape_from_safe_string_bypasses_escaping() {
        let env = sql_env_with("test.sql", "{{ value }}");
        let tmpl = env.get_template("test.sql").unwrap();

        let safe_value = Value::from_safe_string("1 OR 1=1".to_string());
        let result = tmpl.render(context!(value => safe_value)).unwrap();
        assert_eq!(result, "1 OR 1=1");

        let normal_value = Value::from("1 OR 1=1");
        let result = tmpl.render(context!(value => normal_value)).unwrap();
        assert_eq!(result, "'1 OR 1=1'");
    }

    // ---- control flow: escaping applies only to emitted output ----

    #[test]
    fn test_sql_escape_only_on_output_not_in_loops() {
        let template =
            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
        let env = sql_env_with("test.sql", template);
        let tmpl = env.get_template("test.sql").unwrap();

        let items = vec!["alice", "bob", "charlie"];
        let result = tmpl.render(context!(items => items)).unwrap();
        assert_eq!(result, "'alice', 'bob', 'charlie'");
    }

    #[test]
    fn test_sql_escape_only_on_output_not_in_conditionals() {
        let template = r#"{% if enabled %}{{ value }}{% else %}NULL{% endif %}"#;
        let env = sql_env_with("test.sql", template);
        let tmpl = env.get_template("test.sql").unwrap();

        let result = tmpl
            .render(context!(enabled => true, value => "test"))
            .unwrap();
        assert_eq!(result, "'test'");

        let result = tmpl
            .render(context!(enabled => false, value => "test"))
            .unwrap();
        assert_eq!(result, "NULL");
    }

    #[test]
    fn test_sql_escape_loop_over_map_keys() {
        let template = r#"{% for key, val in data|items %}{{ key }} = {{ val }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
        let env = sql_env_with("test.sql", template);
        let tmpl = env.get_template("test.sql").unwrap();

        let mut data = BTreeMap::new();
        data.insert("name", "Alice");
        data.insert("role", "admin");

        let result = tmpl.render(context!(data => data)).unwrap();
        assert_eq!(result, "'name' = 'Alice', 'role' = 'admin'");
    }

    #[test]
    fn test_sql_escape_nested_loop() {
        let template = r#"{% for row in rows %}({% for col in row %}{{ col }}{% if not loop.last %}, {% endif %}{% endfor %}){% if not loop.last %}, {% endif %}{% endfor %}"#;
        let env = sql_env_with("test.sql", template);
        let tmpl = env.get_template("test.sql").unwrap();

        let rows: Vec<Vec<&str>> = vec![vec!["a", "b"], vec!["c", "d"]];
        let result = tmpl.render(context!(rows => rows)).unwrap();
        assert_eq!(result, "('a', 'b'), ('c', 'd')");
    }

    #[test]
    fn test_sql_escape_length_filter_works() {
        let template = r#"{% if items|length > 0 %}{{ items|length }}{% else %}0{% endif %}"#;
        let env = sql_env_with("test.sql", template);
        let tmpl = env.get_template("test.sql").unwrap();

        let items = vec![1, 2, 3];
        let result = tmpl.render(context!(items => items)).unwrap();
        assert_eq!(result, "3");
    }

    // ---- auto-escape callback ----

    #[test]
    fn test_auto_escape_callback_all_files() {
        for name in &["migration.sql", "path/to/file.sql", "file.txt", "file.html"] {
            assert_eq!(
                auto_escape_callback(name),
                AutoEscape::Custom(FORMAT_NAME),
                "unexpected auto-escape mode for {}",
                name
            );
        }
    }
}