Skip to main content

rustrails_record/
sanitization.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4
5/// Replaces `?` placeholders in an SQL template with sanitized literal values.
6#[must_use]
7pub fn sanitize_sql(template: &str, binds: &[Value]) -> String {
8    let mut binds = binds.iter();
9    let mut sanitized = String::with_capacity(template.len() + binds.len() * 8);
10
11    for character in template.chars() {
12        if character == '?' {
13            if let Some(bind) = binds.next() {
14                sanitized.push_str(&sql_literal(bind));
15            } else {
16                sanitized.push('?');
17            }
18        } else {
19            sanitized.push(character);
20        }
21    }
22
23    sanitized
24}
25
/// Escapes the SQL `LIKE` metacharacters in `input` with a backslash.
///
/// `%`, `_`, and `\` itself are each prefixed with `\`; every other character
/// is copied unchanged. The result only acts as a literal pattern when the
/// `LIKE` clause treats backslash as its escape character (spell it out with
/// `ESCAPE '\'` where the database has no default escape).
#[must_use]
pub fn sanitize_sql_like(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            if matches!(ch, '%' | '_' | '\\') {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
41
42/// Joins SQL literals into a comma-separated list.
43#[must_use]
44pub fn sanitize_sql_array(values: &[Value]) -> String {
45    values
46        .iter()
47        .map(sql_literal)
48        .collect::<Vec<_>>()
49        .join(", ")
50}
51
52/// Builds deterministic `key = value` predicates joined by `AND`.
53#[must_use]
54pub fn sanitize_sql_hash(hash: &HashMap<String, Value>) -> String {
55    let mut pairs = hash.iter().collect::<Vec<_>>();
56    pairs.sort_by(|left, right| left.0.cmp(right.0));
57
58    pairs
59        .into_iter()
60        .map(|(key, value)| {
61            if value.is_null() {
62                format!("{key} IS NULL")
63            } else {
64                format!("{key} = {}", sql_literal(value))
65            }
66        })
67        .collect::<Vec<_>>()
68        .join(" AND ")
69}
70
71fn sql_literal(value: &Value) -> String {
72    match value {
73        Value::Null => "NULL".to_owned(),
74        Value::Bool(flag) => {
75            if *flag {
76                "TRUE".to_owned()
77            } else {
78                "FALSE".to_owned()
79            }
80        }
81        Value::Number(number) => number.to_string(),
82        Value::String(text) => format!("'{}'", text.replace('\'', "''")),
83        Value::Array(_) | Value::Object(_) => match serde_json::to_string(value) {
84            Ok(serialized) => format!("'{}'", serialized.replace('\'', "''")),
85            Err(_) => "NULL".to_owned(),
86        },
87    }
88}
89
#[cfg(test)]
mod tests {
    //! Behavioural tests for the SQL sanitization helpers.

    use std::collections::HashMap;

    use serde_json::{Value, json};

    use super::{sanitize_sql, sanitize_sql_array, sanitize_sql_hash, sanitize_sql_like};

    /// Builds a condition map from borrowed column names and JSON values.
    fn conditions<const N: usize>(entries: [(&str, Value); N]) -> HashMap<String, Value> {
        entries
            .into_iter()
            .map(|(column, value)| (column.to_owned(), value))
            .collect()
    }

    #[test]
    fn sanitize_sql_replaces_placeholders_in_order() {
        let binds = [json!("Alice"), json!(21)];
        assert_eq!(
            sanitize_sql("name = ? AND age >= ?", &binds),
            "name = 'Alice' AND age >= 21"
        );
    }

    #[test]
    fn sanitize_sql_leaves_extra_placeholders_when_binds_run_out() {
        let rendered = sanitize_sql("name = ? AND age = ?", &[json!("Alice")]);
        assert_eq!(rendered, "name = 'Alice' AND age = ?");
    }

    #[test]
    fn sanitize_sql_ignores_extra_binds() {
        assert_eq!(sanitize_sql("id = ?", &[json!(1), json!(2)]), "id = 1");
    }

    #[test]
    fn sanitize_sql_escapes_string_quotes() {
        assert_eq!(sanitize_sql("name = ?", &[json!("O'Brien")]), "name = 'O''Brien'");
    }

    #[test]
    fn sanitize_sql_handles_null_boolean_and_numbers() {
        let template = "deleted_at IS ? OR active = ? OR score = ?";
        let binds = [Value::Null, json!(true), json!(12.5)];
        let expected = "deleted_at IS NULL OR active = TRUE OR score = 12.5";
        assert_eq!(sanitize_sql(template, &binds), expected);
    }

    #[test]
    fn sanitize_sql_serializes_json_values() {
        let payload = json!({"role": "admin"});
        assert_eq!(
            sanitize_sql("payload = ?", &[payload]),
            r#"payload = '{"role":"admin"}'"#
        );
    }

    #[test]
    fn sanitize_sql_neutralizes_common_injection_payloads() {
        let rendered = sanitize_sql("name = ?", &[json!("' OR 1=1 --")]);

        assert_eq!(rendered, "name = ''' OR 1=1 --'");
        assert!(!rendered.contains("name = ' OR 1=1 --"));
    }

    #[test]
    fn sanitize_sql_like_escapes_percent_underscore_and_backslash() {
        assert_eq!(sanitize_sql_like(r"100%_done\today"), r"100\%\_done\\today");
    }

    #[test]
    fn sanitize_sql_like_leaves_safe_text_unchanged() {
        let input = "plain-text";
        assert_eq!(sanitize_sql_like(input), input);
    }

    #[test]
    fn sanitize_sql_array_joins_values() {
        let values = [json!(1), json!("Alice"), Value::Null];
        assert_eq!(sanitize_sql_array(&values), "1, 'Alice', NULL");
    }

    #[test]
    fn sanitize_sql_array_returns_empty_string_for_empty_input() {
        assert!(sanitize_sql_array(&[]).is_empty());
    }

    #[test]
    fn sanitize_sql_hash_sorts_keys_for_determinism() {
        let hash = conditions([("name", json!("Alice")), ("age", json!(30))]);
        assert_eq!(sanitize_sql_hash(&hash), "age = 30 AND name = 'Alice'");
    }

    #[test]
    fn sanitize_sql_hash_uses_is_null_for_null_values() {
        let hash = conditions([("deleted_at", Value::Null)]);
        assert_eq!(sanitize_sql_hash(&hash), "deleted_at IS NULL");
    }

    #[test]
    fn sanitize_sql_hash_escapes_injection_strings() {
        let hash = conditions([("name", json!("Robert'); DROP TABLE users;--"))]);
        assert_eq!(sanitize_sql_hash(&hash), "name = 'Robert''); DROP TABLE users;--'");
    }

    // Generates one test asserting how a single bind renders in `value = ?`.
    macro_rules! single_bind_case {
        ($test_name:ident: $bind:expr => $rendered:expr) => {
            #[test]
            fn $test_name() {
                let binds = [$bind];
                assert_eq!(sanitize_sql("value = ?", &binds), $rendered);
            }
        };
    }

    single_bind_case!(sanitize_sql_string_case: json!("hello") => "value = 'hello'");
    single_bind_case!(sanitize_sql_integer_case: json!(42) => "value = 42");
    single_bind_case!(sanitize_sql_float_case: json!(2.72) => "value = 2.72");
    single_bind_case!(sanitize_sql_false_case: json!(false) => "value = FALSE");
    single_bind_case!(sanitize_sql_null_case: Value::Null => "value = NULL");
}