1pub use sqlparser::dialect::*;
2
3use std::fmt::Write;
4use std::ops::ControlFlow;
5
6use sqlparser::{ast::visit_expressions_mut, parser::Parser};
7
8#[derive(thiserror::Error, Debug)]
9pub enum Error {
10 #[error("failed to parse the query")]
11 FailedToParse(#[from] sqlparser::parser::ParserError),
12 #[error("failed to stringify the redacted query; check the memory usage")]
13 FailedToStringify(#[from] std::fmt::Error),
14}
15
16pub fn redact(dialect: &dyn Dialect, sql: &str) -> Result<String, Error> {
18 let statements = Parser::parse_sql(dialect, sql)?;
19 let mut redacted = String::with_capacity(sql.len());
20 for mut stmt in statements {
21 visit_expressions_mut(&mut stmt, |expr| {
22 if let sqlparser::ast::Expr::Value(value) = expr {
23 *value = sqlparser::ast::Value::Placeholder(String::from("?"));
24 }
25 ControlFlow::<()>::Continue(())
26 });
27 write!(redacted, "{};", stmt)?;
28 }
29
30 Ok(redacted)
31}
32
#[cfg(test)]
mod tests {
    use super::*;

    /// Redacts `sql` with the MySQL dialect, panicking if redaction fails.
    fn redact_mysql(sql: &str) -> String {
        redact(&MySqlDialect {}, sql).unwrap()
    }

    #[test]
    fn test_mysql() {
        // (input SQL, expected redacted output), checked in order.
        let cases: &[(&str, &str)] = &[
            (
                "SELECT * FROM foo WHERE bar = 1",
                "SELECT * FROM foo WHERE bar = ?;",
            ),
            (
                "SELECT user, article, 5 FROM articles WHERE user = 100 AND is_deleted = false",
                "SELECT user, article, ? FROM articles WHERE user = ? AND is_deleted = ?;",
            ),
            (
                "SELECT * FROM users
                 WHERE age > 18
                 AND city = 'New York'
                 ORDER BY last_name ASC;",
                "SELECT * FROM users WHERE age > ? AND city = ? ORDER BY last_name ASC;",
            ),
            (
                "UPDATE customers
                 SET email = 'newemail@example.com', last_purchase_date = '2022-03-31'
                 WHERE customer_id = 12345;",
                "UPDATE customers SET email = ?, last_purchase_date = ? WHERE customer_id = ?;",
            ),
            (
                "INSERT INTO users (name, email, age)
                 VALUES ('John Doe', 'johndoe@example.com', 25);",
                "INSERT INTO users (name, email, age) VALUES (?, ?, ?);",
            ),
            (
                "DELETE FROM users
                 WHERE email = 'johndoe@example.com';",
                "DELETE FROM users WHERE email = ?;",
            ),
            (
                "SELECT c.name AS category, AVG(p.price) AS avg_price
                 FROM products p
                 JOIN categories c ON p.category_id = c.id
                 GROUP BY c.name;",
                "SELECT c.name AS category, AVG(p.price) AS avg_price FROM products AS p JOIN categories AS c ON p.category_id = c.id GROUP BY c.name;",
            ),
            (
                "SELECT name, email, age
                 FROM users
                 WHERE age BETWEEN 18 AND 30
                 ORDER BY RAND() LIMIT 5;",
                "SELECT name, email, age FROM users WHERE age BETWEEN ? AND ? ORDER BY RAND() LIMIT ?;",
            ),
            (
                "SELECT c.name AS category, AVG(p.price) AS avg_price
                 FROM products p
                 JOIN categories c ON p.category_id = c.id
                 GROUP BY c.name;
                 SELECT name, email, age
                 FROM users
                 WHERE age BETWEEN 18 AND 30
                 ORDER BY RAND() LIMIT 5;",
                "SELECT c.name AS category, AVG(p.price) AS avg_price FROM products AS p JOIN categories AS c ON p.category_id = c.id GROUP BY c.name;SELECT name, email, age FROM users WHERE age BETWEEN ? AND ? ORDER BY RAND() LIMIT ?;",
            ),
        ];

        for &(sql, expected) in cases {
            assert_eq!(redact_mysql(sql), expected);
        }

        // Input that is not SQL must surface as a parse error, not a panic.
        assert!(matches!(
            redact(&MySqlDialect {}, "this is not a sql."),
            Err(Error::FailedToParse(_))
        ));
    }
}