rustrails_record/
sanitization.rs1use std::collections::HashMap;
2
3use serde_json::Value;
4
5#[must_use]
7pub fn sanitize_sql(template: &str, binds: &[Value]) -> String {
8 let mut binds = binds.iter();
9 let mut sanitized = String::with_capacity(template.len() + binds.len() * 8);
10
11 for character in template.chars() {
12 if character == '?' {
13 if let Some(bind) = binds.next() {
14 sanitized.push_str(&sql_literal(bind));
15 } else {
16 sanitized.push('?');
17 }
18 } else {
19 sanitized.push(character);
20 }
21 }
22
23 sanitized
24}
25
/// Escapes `%`, `_` and `\` in `input` by prefixing each with `\`, so the text
/// matches literally when embedded in a `LIKE` pattern.
///
/// The result assumes `\` is the pattern's escape character; pair it with an
/// explicit `ESCAPE '\'` clause on databases that have no default LIKE escape
/// character. Use [`sanitize_sql_like_with_escape`] to choose a different one.
#[must_use]
pub fn sanitize_sql_like(input: &str) -> String {
    sanitize_sql_like_with_escape(input, '\\')
}

/// Escapes `%`, `_` and `escape` itself in `input` by prefixing each with
/// `escape`, for use in a `LIKE ... ESCAPE '<escape>'` pattern.
///
/// Because every occurrence of the escape character is escaped as well,
/// choosing `%` or `_` as the escape character also works (`%` becomes `%%`).
#[must_use]
pub fn sanitize_sql_like_with_escape(input: &str, escape: char) -> String {
    let mut sanitized = String::with_capacity(input.len());
    for character in input.chars() {
        if character == escape || matches!(character, '%' | '_') {
            sanitized.push(escape);
        }
        sanitized.push(character);
    }
    sanitized
}

42#[must_use]
44pub fn sanitize_sql_array(values: &[Value]) -> String {
45 values
46 .iter()
47 .map(sql_literal)
48 .collect::<Vec<_>>()
49 .join(", ")
50}
51
52#[must_use]
54pub fn sanitize_sql_hash(hash: &HashMap<String, Value>) -> String {
55 let mut pairs = hash.iter().collect::<Vec<_>>();
56 pairs.sort_by(|left, right| left.0.cmp(right.0));
57
58 pairs
59 .into_iter()
60 .map(|(key, value)| {
61 if value.is_null() {
62 format!("{key} IS NULL")
63 } else {
64 format!("{key} = {}", sql_literal(value))
65 }
66 })
67 .collect::<Vec<_>>()
68 .join(" AND ")
69}
70
71fn sql_literal(value: &Value) -> String {
72 match value {
73 Value::Null => "NULL".to_owned(),
74 Value::Bool(flag) => {
75 if *flag {
76 "TRUE".to_owned()
77 } else {
78 "FALSE".to_owned()
79 }
80 }
81 Value::Number(number) => number.to_string(),
82 Value::String(text) => format!("'{}'", text.replace('\'', "''")),
83 Value::Array(_) | Value::Object(_) => match serde_json::to_string(value) {
84 Ok(serialized) => format!("'{}'", serialized.replace('\'', "''")),
85 Err(_) => "NULL".to_owned(),
86 },
87 }
88}
89
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{Value, json};

    use super::{sanitize_sql, sanitize_sql_array, sanitize_sql_hash, sanitize_sql_like};

    // Builds an owned condition map from `(column, value)` pairs.
    fn conditions<const N: usize>(entries: [(&str, Value); N]) -> HashMap<String, Value> {
        entries
            .into_iter()
            .map(|(column, value)| (column.to_owned(), value))
            .collect()
    }

    #[test]
    fn sanitize_sql_replaces_placeholders_in_order() {
        assert_eq!(
            sanitize_sql("name = ? AND age >= ?", &[json!("Alice"), json!(21)]),
            "name = 'Alice' AND age >= 21"
        );
    }

    #[test]
    fn sanitize_sql_leaves_extra_placeholders_when_binds_run_out() {
        let binds = [json!("Alice")];
        let expected = "name = 'Alice' AND age = ?";
        assert_eq!(sanitize_sql("name = ? AND age = ?", &binds), expected);
    }

    #[test]
    fn sanitize_sql_ignores_extra_binds() {
        let binds = [json!(1), json!(2)];
        assert_eq!(sanitize_sql("id = ?", &binds), "id = 1");
    }

    #[test]
    fn sanitize_sql_escapes_string_quotes() {
        assert_eq!(
            sanitize_sql("name = ?", &[json!("O'Brien")]),
            "name = 'O''Brien'"
        );
    }

    #[test]
    fn sanitize_sql_handles_null_boolean_and_numbers() {
        let binds = [Value::Null, json!(true), json!(12.5)];
        let rendered = sanitize_sql("deleted_at IS ? OR active = ? OR score = ?", &binds);
        assert_eq!(rendered, "deleted_at IS NULL OR active = TRUE OR score = 12.5");
    }

    #[test]
    fn sanitize_sql_serializes_json_values() {
        let payload = json!({"role": "admin"});
        assert_eq!(
            sanitize_sql("payload = ?", &[payload]),
            r#"payload = '{"role":"admin"}'"#
        );
    }

    #[test]
    fn sanitize_sql_neutralizes_common_injection_payloads() {
        let rendered = sanitize_sql("name = ?", &[json!("' OR 1=1 --")]);

        assert_eq!(rendered, "name = ''' OR 1=1 --'");
        assert!(!rendered.contains("name = ' OR 1=1 --"));
    }

    #[test]
    fn sanitize_sql_like_escapes_percent_underscore_and_backslash() {
        assert_eq!(sanitize_sql_like(r"100%_done\today"), r"100\%\_done\\today");
    }

    #[test]
    fn sanitize_sql_like_leaves_safe_text_unchanged() {
        let text = "plain-text";
        assert_eq!(sanitize_sql_like(text), text);
    }

    #[test]
    fn sanitize_sql_array_joins_values() {
        let values = [json!(1), json!("Alice"), Value::Null];
        assert_eq!(sanitize_sql_array(&values), "1, 'Alice', NULL");
    }

    #[test]
    fn sanitize_sql_array_returns_empty_string_for_empty_input() {
        assert!(sanitize_sql_array(&[]).is_empty());
    }

    #[test]
    fn sanitize_sql_hash_sorts_keys_for_determinism() {
        let hash = conditions([("name", json!("Alice")), ("age", json!(30))]);
        assert_eq!(sanitize_sql_hash(&hash), "age = 30 AND name = 'Alice'");
    }

    #[test]
    fn sanitize_sql_hash_uses_is_null_for_null_values() {
        let hash = conditions([("deleted_at", Value::Null)]);
        assert_eq!(sanitize_sql_hash(&hash), "deleted_at IS NULL");
    }

    #[test]
    fn sanitize_sql_hash_escapes_injection_strings() {
        let hash = conditions([("name", json!("Robert'); DROP TABLE users;--"))]);
        assert_eq!(
            sanitize_sql_hash(&hash),
            "name = 'Robert''); DROP TABLE users;--'"
        );
    }

    // Generates a test checking how a single bind renders in `value = ?`.
    macro_rules! literal_case {
        ($test_name:ident: $bind:expr => $expected:literal) => {
            #[test]
            fn $test_name() {
                let rendered = sanitize_sql("value = ?", &[$bind]);
                assert_eq!(rendered, $expected);
            }
        };
    }

    literal_case!(sanitize_sql_string_case: json!("hello") => "value = 'hello'");
    literal_case!(sanitize_sql_integer_case: json!(42) => "value = 42");
    literal_case!(sanitize_sql_float_case: json!(2.72) => "value = 2.72");
    literal_case!(sanitize_sql_false_case: json!(false) => "value = FALSE");
    literal_case!(sanitize_sql_null_case: Value::Null => "value = NULL");
}