1mod config;
11mod config_generator;
12mod matcher;
13mod rewriter;
14pub mod strategy;
15
16pub use config::RedactConfig;
17#[allow(unused_imports)]
19pub use config::{RedactConfigBuilder, RedactYamlConfig, Rule};
20pub use config_generator::generate_config;
21pub use matcher::ColumnMatcher;
22pub use rewriter::ValueRewriter;
23pub use strategy::StrategyKind;
24
25use crate::parser::postgres_copy::parse_copy_columns;
26use crate::parser::{Parser, SqlDialect, StatementType};
27use crate::schema::{Schema, SchemaBuilder};
28use ahash::AHashMap;
29use schemars::JsonSchema;
30use std::fs::File;
31use std::io::{BufWriter, Write};
32use std::path::Path;
33
34#[derive(Debug, Default, serde::Serialize, JsonSchema)]
36pub struct RedactStats {
37 pub tables_processed: usize,
39 pub rows_redacted: u64,
41 pub columns_redacted: u64,
43 pub table_stats: Vec<TableRedactStats>,
45 pub warnings: Vec<String>,
47}
48
49#[derive(Debug, Clone, serde::Serialize, JsonSchema)]
51pub struct TableRedactStats {
52 pub name: String,
53 pub rows_processed: u64,
54 pub columns_redacted: u64,
55}
56
57pub struct Redactor {
59 config: RedactConfig,
60 schema: Schema,
61 matcher: ColumnMatcher,
62 rewriter: ValueRewriter,
63 stats: RedactStats,
64 pending_copy: Option<PendingCopy>,
66}
67
/// A PostgreSQL `COPY` header held back until the following data block arrives.
struct PendingCopy {
    /// Raw bytes of the `COPY` statement, emitted verbatim in front of the data.
    header: Vec<u8>,
    /// Target table named by the `COPY` statement.
    table_name: String,
    /// Column list parsed from the header with `parse_copy_columns`.
    columns: Vec<String>,
}
74
75impl Redactor {
76 pub fn new(config: RedactConfig) -> anyhow::Result<Self> {
78 let schema = Self::build_schema(&config.input, config.dialect)?;
80
81 let matcher = ColumnMatcher::from_config(&config)?;
83
84 let rewriter = ValueRewriter::new(config.seed, config.dialect, config.locale.clone());
86
87 Ok(Self {
88 config,
89 schema,
90 matcher,
91 rewriter,
92 stats: RedactStats::default(),
93 pending_copy: None,
94 })
95 }
96
97 fn build_schema(input: &Path, dialect: SqlDialect) -> anyhow::Result<Schema> {
99 let file = File::open(input)?;
100 let mut parser = Parser::with_dialect(file, 64 * 1024, dialect);
101 let mut builder = SchemaBuilder::new();
102
103 while let Some(stmt) = parser.read_statement()? {
104 let (stmt_type, _table_name) =
105 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
106
107 if stmt_type == StatementType::CreateTable {
108 let stmt_str = String::from_utf8_lossy(&stmt);
109 builder.parse_create_table(&stmt_str);
110 }
111 }
112
113 Ok(builder.build())
114 }
115
116 pub fn run(&mut self) -> anyhow::Result<RedactStats> {
118 if self.config.dry_run {
119 return self.dry_run();
120 }
121
122 let output: Box<dyn Write> = if let Some(ref path) = self.config.output {
124 Box::new(BufWriter::new(File::create(path)?))
125 } else {
126 Box::new(std::io::stdout())
127 };
128
129 self.process_file(output)?;
130
131 Ok(std::mem::take(&mut self.stats))
132 }
133
134 fn dry_run(&mut self) -> anyhow::Result<RedactStats> {
136 let file = File::open(&self.config.input)?;
137 let mut parser = Parser::with_dialect(file, 64 * 1024, self.config.dialect);
138
139 let mut tables_seen: AHashMap<String, u64> = AHashMap::new();
140
141 while let Some(stmt) = parser.read_statement()? {
142 let (stmt_type, table_name) =
143 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, self.config.dialect);
144
145 if !table_name.is_empty()
146 && (stmt_type == StatementType::Insert || stmt_type == StatementType::Copy)
147 {
148 *tables_seen.entry(table_name).or_insert(0) += 1;
149 }
150 }
151
152 for (name, count) in tables_seen {
154 if let Some(table) = self.schema.get_table(&name) {
155 let columns_matched = self.matcher.count_matches(&name, table);
156 if columns_matched > 0 {
157 self.stats.tables_processed += 1;
158 self.stats.rows_redacted += count;
159 self.stats.columns_redacted += columns_matched as u64 * count;
160 self.stats.table_stats.push(TableRedactStats {
161 name,
162 rows_processed: count,
163 columns_redacted: columns_matched as u64,
164 });
165 }
166 }
167 }
168
169 Ok(std::mem::take(&mut self.stats))
170 }
171
172 fn process_file(&mut self, mut output: Box<dyn Write>) -> anyhow::Result<()> {
174 let file = File::open(&self.config.input)?;
175 let mut parser = Parser::with_dialect(file, 64 * 1024, self.config.dialect);
176
177 while let Some(stmt) = parser.read_statement()? {
178 let (stmt_type, table_name) =
179 Parser::<&[u8]>::parse_statement_with_dialect(&stmt, self.config.dialect);
180
181 let redacted = match stmt_type {
182 StatementType::Insert if !table_name.is_empty() => {
183 self.redact_insert(&stmt, &table_name)?
184 }
185 StatementType::Copy if !table_name.is_empty() => {
186 if self.config.dialect == SqlDialect::Postgres {
188 let header_str = String::from_utf8_lossy(&stmt);
189 let columns = parse_copy_columns(&header_str);
190 self.pending_copy = Some(PendingCopy {
191 header: stmt.clone(),
192 table_name: table_name.clone(),
193 columns,
194 });
195 continue;
197 }
198 self.redact_copy(&stmt, &table_name)?
199 }
200 StatementType::Unknown
201 if self.config.dialect == SqlDialect::Postgres
202 && self.pending_copy.is_some()
203 && (stmt.ends_with(b"\\.\n") || stmt.ends_with(b"\\.\r\n")) =>
204 {
205 self.redact_copy_data(&stmt)?
207 }
208 _ => {
209 if let Some(pending) = self.pending_copy.take() {
212 output.write_all(&pending.header)?;
213 }
214 stmt
215 }
216 };
217
218 output.write_all(&redacted)?;
219 }
220
221 if let Some(pending) = self.pending_copy.take() {
223 output.write_all(&pending.header)?;
224 }
225
226 output.flush()?;
227 Ok(())
228 }
229
230 fn redact_insert(&mut self, stmt: &[u8], table_name: &str) -> anyhow::Result<Vec<u8>> {
232 if self.should_skip_table(table_name) {
234 return Ok(stmt.to_vec());
235 }
236
237 let Some(table) = self.schema.get_table(table_name) else {
239 self.stats.warnings.push(format!(
240 "No schema for table '{}', passing through unchanged",
241 table_name
242 ));
243 return Ok(stmt.to_vec());
244 };
245
246 let strategies = self.matcher.get_strategies(table_name, table);
248
249 if strategies.iter().all(|s| matches!(s, StrategyKind::Skip)) {
251 return Ok(stmt.to_vec());
252 }
253
254 let (redacted, rows_redacted, cols_redacted) =
256 self.rewriter
257 .rewrite_insert(stmt, table_name, table, &strategies)?;
258
259 if rows_redacted > 0 {
261 self.stats.rows_redacted += rows_redacted;
262 self.stats.columns_redacted += cols_redacted;
263
264 if let Some(ts) = self
266 .stats
267 .table_stats
268 .iter_mut()
269 .find(|t| t.name == table_name)
270 {
271 ts.rows_processed += rows_redacted;
272 ts.columns_redacted += cols_redacted;
273 } else {
274 self.stats.tables_processed += 1;
275 self.stats.table_stats.push(TableRedactStats {
276 name: table_name.to_string(),
277 rows_processed: rows_redacted,
278 columns_redacted: cols_redacted,
279 });
280 }
281 }
282
283 Ok(redacted)
284 }
285
286 fn redact_copy(&mut self, stmt: &[u8], table_name: &str) -> anyhow::Result<Vec<u8>> {
288 if self.should_skip_table(table_name) {
290 return Ok(stmt.to_vec());
291 }
292
293 let Some(table) = self.schema.get_table(table_name) else {
295 self.stats.warnings.push(format!(
296 "No schema for table '{}', passing through unchanged",
297 table_name
298 ));
299 return Ok(stmt.to_vec());
300 };
301
302 let strategies = self.matcher.get_strategies(table_name, table);
304
305 if strategies.iter().all(|s| matches!(s, StrategyKind::Skip)) {
307 return Ok(stmt.to_vec());
308 }
309
310 let (redacted, rows_redacted, cols_redacted) =
312 self.rewriter
313 .rewrite_copy(stmt, table_name, table, &strategies)?;
314
315 if rows_redacted > 0 {
317 self.stats.rows_redacted += rows_redacted;
318 self.stats.columns_redacted += cols_redacted;
319
320 if let Some(ts) = self
322 .stats
323 .table_stats
324 .iter_mut()
325 .find(|t| t.name == table_name)
326 {
327 ts.rows_processed += rows_redacted;
328 ts.columns_redacted += cols_redacted;
329 } else {
330 self.stats.tables_processed += 1;
331 self.stats.table_stats.push(TableRedactStats {
332 name: table_name.to_string(),
333 rows_processed: rows_redacted,
334 columns_redacted: cols_redacted,
335 });
336 }
337 }
338
339 Ok(redacted)
340 }
341
342 fn redact_copy_data(&mut self, data_block: &[u8]) -> anyhow::Result<Vec<u8>> {
344 let pending = self
345 .pending_copy
346 .take()
347 .ok_or_else(|| anyhow::anyhow!("COPY data block without pending header"))?;
348
349 let table_name = &pending.table_name;
350
351 if self.should_skip_table(table_name) {
353 let mut result = pending.header;
355 result.extend_from_slice(data_block);
356 return Ok(result);
357 }
358
359 let Some(table) = self.schema.get_table(table_name) else {
361 self.stats.warnings.push(format!(
362 "No schema for table '{}', passing through unchanged",
363 table_name
364 ));
365 let mut result = pending.header;
366 result.extend_from_slice(data_block);
367 return Ok(result);
368 };
369
370 let strategies = self.matcher.get_strategies(table_name, table);
372
373 if strategies.iter().all(|s| matches!(s, StrategyKind::Skip)) {
375 let mut result = pending.header;
376 result.extend_from_slice(data_block);
377 return Ok(result);
378 }
379
380 let (redacted_data, rows_redacted, cols_redacted) =
382 self.rewriter
383 .rewrite_copy_data(data_block, table, &strategies, &pending.columns)?;
384
385 if rows_redacted > 0 {
387 self.stats.rows_redacted += rows_redacted;
388 self.stats.columns_redacted += cols_redacted;
389
390 if let Some(ts) = self
391 .stats
392 .table_stats
393 .iter_mut()
394 .find(|t| t.name == *table_name)
395 {
396 ts.rows_processed += rows_redacted;
397 ts.columns_redacted += cols_redacted;
398 } else {
399 self.stats.tables_processed += 1;
400 self.stats.table_stats.push(TableRedactStats {
401 name: table_name.to_string(),
402 rows_processed: rows_redacted,
403 columns_redacted: cols_redacted,
404 });
405 }
406 }
407
408 let mut result = pending.header;
411 if !result.ends_with(b"\n") {
412 result.push(b'\n');
413 }
414 result.extend_from_slice(&redacted_data);
415 Ok(result)
416 }
417
418 fn should_skip_table(&self, name: &str) -> bool {
420 if self
422 .config
423 .exclude
424 .iter()
425 .any(|e| e.eq_ignore_ascii_case(name))
426 {
427 return true;
428 }
429
430 if let Some(ref tables) = self.config.tables_filter {
432 if !tables.iter().any(|t| t.eq_ignore_ascii_case(name)) {
433 return true;
434 }
435 }
436
437 false
438 }
439}