sql_redactor/
lib.rs

1pub use sqlparser::dialect::*;
2
3use std::fmt::Write;
4use std::ops::ControlFlow;
5
6use sqlparser::{ast::visit_expressions_mut, parser::Parser};
7
8#[derive(thiserror::Error, Debug)]
9pub enum Error {
10    #[error("failed to parse the query")]
11    FailedToParse(#[from] sqlparser::parser::ParserError),
12    #[error("failed to stringify the redacted query; check the memory usage")]
13    FailedToStringify(#[from] std::fmt::Error),
14}
15
16/// Replace all Value nodes in an AST with '?'
17pub fn redact(dialect: &dyn Dialect, sql: &str) -> Result<String, Error> {
18    let statements = Parser::parse_sql(dialect, sql)?;
19    let mut redacted = String::with_capacity(sql.len());
20    for mut stmt in statements {
21        visit_expressions_mut(&mut stmt, |expr| {
22            if let sqlparser::ast::Expr::Value(value) = expr {
23                *value = sqlparser::ast::Value::Placeholder(String::from("?"));
24            }
25            ControlFlow::<()>::Continue(())
26        });
27        write!(redacted, "{};", stmt)?;
28    }
29
30    Ok(redacted)
31}
32
#[cfg(test)]
mod tests {
    use super::*;

    /// Redacts `sql` with the MySQL dialect and checks the rendered result.
    fn assert_redacts_to(sql: &str, expected: &str) {
        let actual = redact(&MySqlDialect {}, sql).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_mysql() {
        // Single-line queries: every literal becomes `?`.
        assert_redacts_to(
            "SELECT * FROM foo WHERE bar = 1",
            "SELECT * FROM foo WHERE bar = ?;",
        );
        assert_redacts_to(
            "SELECT user, article, 5 FROM articles WHERE user = 100 AND is_deleted = false",
            "SELECT user, article, ? FROM articles WHERE user = ? AND is_deleted = ?;",
        );

        // Multi-line queries: the parse/render round trip normalizes whitespace.
        assert_redacts_to(
            "SELECT * FROM users 
            WHERE age > 18 
            AND city = 'New York' 
            ORDER BY last_name ASC;",
            "SELECT * FROM users WHERE age > ? AND city = ? ORDER BY last_name ASC;",
        );
        assert_redacts_to(
            "UPDATE customers 
            SET email = 'newemail@example.com', last_purchase_date = '2022-03-31' 
            WHERE customer_id = 12345;",
            "UPDATE customers SET email = ?, last_purchase_date = ? WHERE customer_id = ?;",
        );
        assert_redacts_to(
            "INSERT INTO users (name, email, age) 
            VALUES ('John Doe', 'johndoe@example.com', 25);",
            "INSERT INTO users (name, email, age) VALUES (?, ?, ?);",
        );
        assert_redacts_to(
            "DELETE FROM users 
            WHERE email = 'johndoe@example.com';",
            "DELETE FROM users WHERE email = ?;",
        );
        assert_redacts_to(
            "SELECT c.name AS category, AVG(p.price) AS avg_price 
            FROM products p 
            JOIN categories c ON p.category_id = c.id 
            GROUP BY c.name;",
            "SELECT c.name AS category, AVG(p.price) AS avg_price FROM products AS p JOIN categories AS c ON p.category_id = c.id GROUP BY c.name;",
        );
        assert_redacts_to(
            "SELECT name, email, age 
            FROM users 
            WHERE age BETWEEN 18 AND 30 
            ORDER BY RAND() LIMIT 5;",
            "SELECT name, email, age FROM users WHERE age BETWEEN ? AND ? ORDER BY RAND() LIMIT ?;",
        );

        // Several statements are each terminated by `;` and concatenated.
        assert_redacts_to(
            "SELECT c.name AS category, AVG(p.price) AS avg_price 
            FROM products p 
            JOIN categories c ON p.category_id = c.id 
            GROUP BY c.name;
            SELECT name, email, age 
            FROM users 
            WHERE age BETWEEN 18 AND 30 
            ORDER BY RAND() LIMIT 5;",
            "SELECT c.name AS category, AVG(p.price) AS avg_price FROM products AS p JOIN categories AS c ON p.category_id = c.id GROUP BY c.name;SELECT name, email, age FROM users WHERE age BETWEEN ? AND ? ORDER BY RAND() LIMIT ?;",
        );

        // Input that is not SQL surfaces as a parse error.
        let outcome = redact(&MySqlDialect {}, "this is not a sql.");
        assert!(matches!(outcome, Err(Error::FailedToParse(_))));
    }
}