1use crate::parser::mysql_insert::{InsertParser, ParsedValue};
7use crate::parser::postgres_copy::{parse_copy_columns, CopyParser};
8use crate::parser::SqlDialect;
9use crate::redactor::strategy::{
10 ConstantStrategy, FakeStrategy, HashStrategy, MaskStrategy, NullStrategy, RedactValue,
11 Strategy, StrategyKind,
12};
13use crate::schema::TableSchema;
14use rand::rngs::StdRng;
15use rand::SeedableRng;
16
17pub struct ValueRewriter {
19 rng: StdRng,
21 dialect: SqlDialect,
23 locale: String,
25}
26
27impl ValueRewriter {
28 pub fn new(seed: Option<u64>, dialect: SqlDialect, locale: String) -> Self {
30 let rng = match seed {
31 Some(s) => StdRng::seed_from_u64(s),
32 None => StdRng::from_os_rng(),
33 };
34 Self {
35 rng,
36 dialect,
37 locale,
38 }
39 }
40
41 pub fn rewrite_insert(
43 &mut self,
44 stmt: &[u8],
45 table_name: &str,
46 table: &TableSchema,
47 strategies: &[StrategyKind],
48 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
49 let mut parser = InsertParser::new(stmt).with_schema(table);
51 let rows = parser.parse_rows()?;
52
53 if rows.is_empty() {
54 return Ok((stmt.to_vec(), 0, 0));
55 }
56
57 let stmt_str = String::from_utf8_lossy(stmt);
59 let column_list = self.extract_column_list(&stmt_str);
60
61 let mut result = self.build_insert_header(table_name, &column_list);
63
64 let mut rows_redacted = 0u64;
65 let mut columns_redacted = 0u64;
66 let num_strategies = strategies.len();
67
68 for (row_idx, row) in rows.iter().enumerate() {
69 if row_idx > 0 {
70 result.extend_from_slice(b",");
71 }
72 result.extend_from_slice(b"\n(");
73
74 let mut row_had_redaction = false;
75
76 for (col_idx, value) in row.values.iter().enumerate() {
77 if col_idx > 0 {
78 result.extend_from_slice(b", ");
79 }
80
81 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
83
84 let (redacted_sql, was_redacted) =
86 self.redact_value(value, strategy, col_idx < num_strategies);
87 result.extend_from_slice(redacted_sql.as_bytes());
88
89 if was_redacted {
90 columns_redacted += 1;
91 row_had_redaction = true;
92 }
93 }
94
95 result.extend_from_slice(b")");
96 if row_had_redaction {
97 rows_redacted += 1;
98 }
99 }
100
101 result.extend_from_slice(b";\n");
102
103 Ok((result, rows_redacted, columns_redacted))
104 }
105
106 pub fn rewrite_copy(
108 &mut self,
109 stmt: &[u8],
110 _table_name: &str,
111 table: &TableSchema,
112 strategies: &[StrategyKind],
113 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
114 let stmt_str = String::from_utf8_lossy(stmt);
118
119 let header_end = stmt_str
121 .find('\n')
122 .ok_or_else(|| anyhow::anyhow!("Invalid COPY statement: no newline"))?;
123 let header = &stmt_str[..header_end];
124 let data_block = &stmt[header_end + 1..];
125
126 let columns = parse_copy_columns(header);
128
129 let mut parser = CopyParser::new(data_block)
131 .with_schema(table)
132 .with_column_order(columns.clone());
133 let rows = parser.parse_rows()?;
134
135 if rows.is_empty() {
136 return Ok((stmt.to_vec(), 0, 0));
137 }
138
139 let mut result = Vec::with_capacity(stmt.len());
141 result.extend_from_slice(header.as_bytes());
142 result.push(b'\n');
143
144 let mut rows_redacted = 0u64;
145 let mut columns_redacted = 0u64;
146
147 for row in &rows {
148 let mut row_had_redaction = false;
149 let mut first = true;
150
151 let values = self.parse_copy_row_values(&row.raw);
153
154 for (col_idx, value) in values.iter().enumerate() {
155 if !first {
156 result.push(b'\t');
157 }
158 first = false;
159
160 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
161 let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
162 result.extend_from_slice(&redacted);
163
164 if was_redacted {
165 columns_redacted += 1;
166 row_had_redaction = true;
167 }
168 }
169
170 result.push(b'\n');
171 if row_had_redaction {
172 rows_redacted += 1;
173 }
174 }
175
176 result.extend_from_slice(b"\\.\n");
178
179 Ok((result, rows_redacted, columns_redacted))
180 }
181
182 pub fn rewrite_copy_data(
184 &mut self,
185 data_block: &[u8],
186 table: &TableSchema,
187 strategies: &[StrategyKind],
188 columns: &[String],
189 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
190 let mut parser = CopyParser::new(data_block)
192 .with_schema(table)
193 .with_column_order(columns.to_vec());
194 let rows = parser.parse_rows()?;
195
196 if rows.is_empty() {
197 return Ok((data_block.to_vec(), 0, 0));
198 }
199
200 let mut result = Vec::with_capacity(data_block.len());
202
203 let mut rows_redacted = 0u64;
204 let mut columns_redacted = 0u64;
205
206 for row in &rows {
207 let mut row_had_redaction = false;
208 let mut first = true;
209
210 let values = self.parse_copy_row_values(&row.raw);
212
213 for (col_idx, value) in values.iter().enumerate() {
214 if !first {
215 result.push(b'\t');
216 }
217 first = false;
218
219 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
220 let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
221 result.extend_from_slice(&redacted);
222
223 if was_redacted {
224 columns_redacted += 1;
225 row_had_redaction = true;
226 }
227 }
228
229 result.push(b'\n');
230 if row_had_redaction {
231 rows_redacted += 1;
232 }
233 }
234
235 result.extend_from_slice(b"\\.\n");
237
238 Ok((result, rows_redacted, columns_redacted))
239 }
240
241 fn parse_copy_row_values(&self, raw: &[u8]) -> Vec<CopyValueRef> {
243 let mut values = Vec::new();
244 let mut start = 0;
245
246 for (i, &b) in raw.iter().enumerate() {
247 if b == b'\t' {
248 values.push(self.parse_single_copy_value(&raw[start..i]));
249 start = i + 1;
250 }
251 }
252 if start <= raw.len() {
254 values.push(self.parse_single_copy_value(&raw[start..]));
255 }
256
257 values
258 }
259
260 fn parse_single_copy_value(&self, raw: &[u8]) -> CopyValueRef {
262 if raw == b"\\N" {
263 CopyValueRef::Null
264 } else {
265 CopyValueRef::Text(raw.to_vec())
266 }
267 }
268
269 fn redact_copy_value(
271 &mut self,
272 value: &CopyValueRef,
273 strategy: &StrategyKind,
274 ) -> (Vec<u8>, bool) {
275 if matches!(strategy, StrategyKind::Skip) {
276 let bytes = match value {
277 CopyValueRef::Null => b"\\N".to_vec(),
278 CopyValueRef::Text(t) => t.clone(),
279 };
280 return (bytes, false);
281 }
282
283 let redact_value = match value {
285 CopyValueRef::Null => RedactValue::Null,
286 CopyValueRef::Text(t) => {
287 let decoded = self.decode_copy_escapes(t);
289 RedactValue::String(String::from_utf8_lossy(&decoded).into_owned())
290 }
291 };
292
293 let result = self.apply_strategy(&redact_value, strategy);
295
296 let bytes = match result {
298 RedactValue::Null => b"\\N".to_vec(),
299 RedactValue::String(s) => self.encode_copy_escapes(&s),
300 RedactValue::Integer(i) => i.to_string().into_bytes(),
301 RedactValue::Bytes(b) => self.encode_copy_escapes(&String::from_utf8_lossy(&b)),
302 };
303
304 (bytes, true)
305 }
306
307 fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
309 let mut result = Vec::with_capacity(value.len());
310 let mut i = 0;
311
312 while i < value.len() {
313 if value[i] == b'\\' && i + 1 < value.len() {
314 let next = value[i + 1];
315 let decoded = match next {
316 b'n' => b'\n',
317 b'r' => b'\r',
318 b't' => b'\t',
319 b'\\' => b'\\',
320 _ => {
321 result.push(b'\\');
322 result.push(next);
323 i += 2;
324 continue;
325 }
326 };
327 result.push(decoded);
328 i += 2;
329 } else {
330 result.push(value[i]);
331 i += 1;
332 }
333 }
334
335 result
336 }
337
338 fn encode_copy_escapes(&self, value: &str) -> Vec<u8> {
340 let mut result = Vec::with_capacity(value.len());
341
342 for b in value.bytes() {
343 match b {
344 b'\n' => result.extend_from_slice(b"\\n"),
345 b'\r' => result.extend_from_slice(b"\\r"),
346 b'\t' => result.extend_from_slice(b"\\t"),
347 b'\\' => result.extend_from_slice(b"\\\\"),
348 _ => result.push(b),
349 }
350 }
351
352 result
353 }
354
355 fn extract_column_list(&self, stmt: &str) -> Option<Vec<String>> {
357 let upper = stmt.to_uppercase();
358 let values_pos = upper.find("VALUES")?;
359 let before_values = &stmt[..values_pos];
360
361 let close_paren = before_values.rfind(')')?;
363 let open_paren = before_values[..close_paren].rfind('(')?;
364
365 let col_list = &before_values[open_paren + 1..close_paren];
366
367 let upper_cols = col_list.to_uppercase();
369 if col_list.trim().is_empty()
370 || upper_cols.contains("SELECT")
371 || upper_cols.contains("VALUES")
372 {
373 return None;
374 }
375
376 let columns: Vec<String> = col_list
377 .split(',')
378 .map(|c| {
379 c.trim()
380 .trim_matches('`')
381 .trim_matches('"')
382 .trim_matches('[')
383 .trim_matches(']')
384 .to_string()
385 })
386 .collect();
387
388 if columns.is_empty() {
389 None
390 } else {
391 Some(columns)
392 }
393 }
394
395 fn build_insert_header(&self, table_name: &str, columns: &Option<Vec<String>>) -> Vec<u8> {
397 let mut result = Vec::new();
398
399 result.extend_from_slice(b"INSERT INTO ");
401 result.extend_from_slice(self.quote_identifier(table_name).as_bytes());
402
403 if let Some(cols) = columns {
405 result.extend_from_slice(b" (");
406 for (i, col) in cols.iter().enumerate() {
407 if i > 0 {
408 result.extend_from_slice(b", ");
409 }
410 result.extend_from_slice(self.quote_identifier(col).as_bytes());
411 }
412 result.extend_from_slice(b")");
413 }
414
415 result.extend_from_slice(b" VALUES");
416 result
417 }
418
419 fn quote_identifier(&self, name: &str) -> String {
421 match self.dialect {
422 SqlDialect::MySql => format!("`{}`", name),
423 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", name),
424 SqlDialect::Mssql => format!("[{}]", name),
425 }
426 }
427
428 fn redact_value(
430 &mut self,
431 value: &ParsedValue,
432 strategy: &StrategyKind,
433 has_strategy: bool,
434 ) -> (String, bool) {
435 if !has_strategy || matches!(strategy, StrategyKind::Skip) {
437 return (self.format_value(value), false);
438 }
439
440 let redact_value = self.parsed_to_redact(value);
442
443 let result = self.apply_strategy(&redact_value, strategy);
445
446 (self.format_redact_value(&result), true)
448 }
449
450 fn parsed_to_redact(&self, value: &ParsedValue) -> RedactValue {
452 match value {
453 ParsedValue::Null => RedactValue::Null,
454 ParsedValue::Integer(n) => RedactValue::Integer(*n),
455 ParsedValue::BigInteger(n) => RedactValue::Integer(*n as i64), ParsedValue::String { value } => RedactValue::String(value.clone()),
457 ParsedValue::Hex(bytes) => RedactValue::Bytes(bytes.clone()),
458 ParsedValue::Other(bytes) => {
459 RedactValue::String(String::from_utf8_lossy(bytes).into_owned())
460 }
461 }
462 }
463
464 fn apply_strategy(&mut self, value: &RedactValue, strategy: &StrategyKind) -> RedactValue {
466 match strategy {
467 StrategyKind::Null => NullStrategy::new().apply(value, &mut self.rng),
468 StrategyKind::Constant { value: constant } => {
469 ConstantStrategy::new(constant.clone()).apply(value, &mut self.rng)
470 }
471 StrategyKind::Hash { preserve_domain } => {
472 HashStrategy::new(*preserve_domain).apply(value, &mut self.rng)
473 }
474 StrategyKind::Mask { pattern } => {
475 MaskStrategy::new(pattern.clone()).apply(value, &mut self.rng)
476 }
477 StrategyKind::Fake { generator } => {
478 FakeStrategy::new(generator.clone(), self.locale.clone())
479 .apply(value, &mut self.rng)
480 }
481 StrategyKind::Shuffle => {
482 value.clone()
485 }
486 StrategyKind::Skip => value.clone(),
487 }
488 }
489
490 fn format_value(&self, value: &ParsedValue) -> String {
492 match value {
493 ParsedValue::Null => "NULL".to_string(),
494 ParsedValue::Integer(n) => n.to_string(),
495 ParsedValue::BigInteger(n) => n.to_string(),
496 ParsedValue::String { value } => self.format_sql_string(value),
497 ParsedValue::Hex(bytes) => String::from_utf8_lossy(bytes).into_owned(),
498 ParsedValue::Other(bytes) => String::from_utf8_lossy(bytes).into_owned(),
499 }
500 }
501
502 fn format_redact_value(&self, value: &RedactValue) -> String {
504 match value {
505 RedactValue::Null => "NULL".to_string(),
506 RedactValue::Integer(n) => n.to_string(),
507 RedactValue::String(s) => self.format_sql_string(s),
508 RedactValue::Bytes(b) => {
509 format!("0x{}", hex::encode(b))
511 }
512 }
513 }
514
515 fn format_sql_string(&self, value: &str) -> String {
517 match self.dialect {
518 SqlDialect::MySql => {
519 let escaped = value
521 .replace('\\', "\\\\")
522 .replace('\'', "\\'")
523 .replace('\n', "\\n")
524 .replace('\r', "\\r")
525 .replace('\t', "\\t")
526 .replace('\0', "\\0");
527 format!("'{}'", escaped)
528 }
529 SqlDialect::Postgres | SqlDialect::Sqlite => {
530 let escaped = value.replace('\'', "''");
532 format!("'{}'", escaped)
533 }
534 SqlDialect::Mssql => {
535 let escaped = value.replace('\'', "''");
537 if value.bytes().any(|b| b > 127) {
539 format!("N'{}'", escaped)
540 } else {
541 format!("'{}'", escaped)
542 }
543 }
544 }
545 }
546}
547
/// One field of a `COPY` text-format row, as read from the data block.
enum CopyValueRef {
    /// The field was the `\N` null marker.
    Null,

    /// Raw field bytes; backslash escapes have not been decoded yet.
    Text(Vec<u8>),
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId, TableSchema};

    /// Builds a rewriter with a fixed seed and the `en` locale.
    fn seeded_rewriter(dialect: SqlDialect) -> ValueRewriter {
        ValueRewriter::new(Some(42), dialect, String::from("en"))
    }

    /// Builds the three-column `users` table (`id`, `email`, `name`) used by the
    /// rewrite tests.
    fn create_test_schema() -> TableSchema {
        let id = Column {
            name: String::from("id"),
            col_type: ColumnType::Int,
            ordinal: ColumnId(0),
            is_primary_key: true,
            is_nullable: false,
        };
        let email = Column {
            name: String::from("email"),
            col_type: ColumnType::Text,
            ordinal: ColumnId(1),
            is_primary_key: false,
            is_nullable: false,
        };
        let name = Column {
            name: String::from("name"),
            col_type: ColumnType::Text,
            ordinal: ColumnId(2),
            is_primary_key: false,
            is_nullable: true,
        };

        TableSchema {
            name: String::from("users"),
            id: TableId(0),
            columns: vec![id, email, name],
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    #[test]
    fn test_rewrite_insert_mysql() {
        let mut rewriter = seeded_rewriter(SqlDialect::MySql);
        let schema = create_test_schema();

        let stmt = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'alice@example.com', 'Alice');";
        // `id` is skipped; `email` and `name` are redacted.
        let strategies = vec![
            StrategyKind::Skip,
            StrategyKind::Hash {
                preserve_domain: true,
            },
            StrategyKind::Fake {
                generator: String::from("name"),
            },
        ];

        let (bytes, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let sql = String::from_utf8_lossy(&bytes);

        assert!(sql.contains("INSERT INTO `users`"));
        assert!(sql.contains("VALUES"));
        assert_eq!(rows, 1);
        assert_eq!(cols, 2);
    }

    #[test]
    fn test_rewrite_insert_mssql() {
        let mut rewriter = seeded_rewriter(SqlDialect::Mssql);
        let schema = create_test_schema();

        let stmt = b"INSERT INTO [users] ([id], [email], [name]) VALUES (1, N'alice@example.com', N'Alice');";
        // Only `email` is redacted (replaced by NULL).
        let strategies = vec![StrategyKind::Skip, StrategyKind::Null, StrategyKind::Skip];

        let (bytes, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let sql = String::from_utf8_lossy(&bytes);

        assert!(sql.contains("INSERT INTO [users]"));
        assert!(sql.contains("NULL"));
        assert_eq!(rows, 1);
        assert_eq!(cols, 1);
    }

    #[test]
    fn test_format_sql_string_mysql() {
        let rewriter = seeded_rewriter(SqlDialect::MySql);

        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it\\'s'");
        assert_eq!(rewriter.format_sql_string("line\nbreak"), "'line\\nbreak'");
    }

    #[test]
    fn test_format_sql_string_postgres() {
        let rewriter = seeded_rewriter(SqlDialect::Postgres);

        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it''s'");
    }

    #[test]
    fn test_format_sql_string_mssql() {
        let rewriter = seeded_rewriter(SqlDialect::Mssql);

        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("café"), "N'café'");
    }

    #[test]
    fn test_quote_identifier() {
        let mysql = ValueRewriter::new(None, SqlDialect::MySql, String::from("en"));
        assert_eq!(mysql.quote_identifier("users"), "`users`");

        let pg = ValueRewriter::new(None, SqlDialect::Postgres, String::from("en"));
        assert_eq!(pg.quote_identifier("users"), "\"users\"");

        let mssql = ValueRewriter::new(None, SqlDialect::Mssql, String::from("en"));
        assert_eq!(mssql.quote_identifier("users"), "[users]");
    }
}