1use crate::parser::mysql_insert::{InsertParser, ParsedValue};
7use crate::parser::postgres_copy::{parse_copy_columns, CopyParser};
8use crate::parser::SqlDialect;
9use crate::redactor::strategy::{
10 ConstantStrategy, FakeStrategy, HashStrategy, MaskStrategy, NullStrategy, RedactValue,
11 Strategy, StrategyKind,
12};
13use crate::schema::TableSchema;
14use rand::rngs::StdRng;
15use rand::SeedableRng;
16
/// Rewrites the values of `INSERT` and `COPY` statements by applying one
/// [`StrategyKind`] per column and re-serializing the rows.
pub struct ValueRewriter {
    /// Random source passed to every strategy's `apply` call; seeded in `new`.
    rng: StdRng,
    /// Selects the identifier quoting and string-literal escaping rules.
    dialect: SqlDialect,
    /// Locale handed to `FakeStrategy::new` when a `Fake` strategy is applied.
    locale: String,
}
26
27impl ValueRewriter {
28 pub fn new(seed: Option<u64>, dialect: SqlDialect, locale: String) -> Self {
30 let rng = match seed {
31 Some(s) => StdRng::seed_from_u64(s),
32 None => StdRng::from_entropy(),
33 };
34 Self { rng, dialect, locale }
35 }
36
37 pub fn rewrite_insert(
39 &mut self,
40 stmt: &[u8],
41 table_name: &str,
42 table: &TableSchema,
43 strategies: &[StrategyKind],
44 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
45 let mut parser = InsertParser::new(stmt).with_schema(table);
47 let rows = parser.parse_rows()?;
48
49 if rows.is_empty() {
50 return Ok((stmt.to_vec(), 0, 0));
51 }
52
53 let stmt_str = String::from_utf8_lossy(stmt);
55 let column_list = self.extract_column_list(&stmt_str);
56
57 let mut result = self.build_insert_header(table_name, &column_list);
59
60 let mut rows_redacted = 0u64;
61 let mut columns_redacted = 0u64;
62 let num_strategies = strategies.len();
63
64 for (row_idx, row) in rows.iter().enumerate() {
65 if row_idx > 0 {
66 result.extend_from_slice(b",");
67 }
68 result.extend_from_slice(b"\n(");
69
70 let mut row_had_redaction = false;
71
72 for (col_idx, value) in row.values.iter().enumerate() {
73 if col_idx > 0 {
74 result.extend_from_slice(b", ");
75 }
76
77 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
79
80 let (redacted_sql, was_redacted) =
82 self.redact_value(value, strategy, col_idx < num_strategies);
83 result.extend_from_slice(redacted_sql.as_bytes());
84
85 if was_redacted {
86 columns_redacted += 1;
87 row_had_redaction = true;
88 }
89 }
90
91 result.extend_from_slice(b")");
92 if row_had_redaction {
93 rows_redacted += 1;
94 }
95 }
96
97 result.extend_from_slice(b";\n");
98
99 Ok((result, rows_redacted, columns_redacted))
100 }
101
102 pub fn rewrite_copy(
104 &mut self,
105 stmt: &[u8],
106 _table_name: &str,
107 table: &TableSchema,
108 strategies: &[StrategyKind],
109 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
110 let stmt_str = String::from_utf8_lossy(stmt);
114
115 let header_end = stmt_str
117 .find('\n')
118 .ok_or_else(|| anyhow::anyhow!("Invalid COPY statement: no newline"))?;
119 let header = &stmt_str[..header_end];
120 let data_block = &stmt[header_end + 1..];
121
122 let columns = parse_copy_columns(header);
124
125 let mut parser = CopyParser::new(data_block)
127 .with_schema(table)
128 .with_column_order(columns.clone());
129 let rows = parser.parse_rows()?;
130
131 if rows.is_empty() {
132 return Ok((stmt.to_vec(), 0, 0));
133 }
134
135 let mut result = Vec::with_capacity(stmt.len());
137 result.extend_from_slice(header.as_bytes());
138 result.push(b'\n');
139
140 let mut rows_redacted = 0u64;
141 let mut columns_redacted = 0u64;
142
143 for row in &rows {
144 let mut row_had_redaction = false;
145 let mut first = true;
146
147 let values = self.parse_copy_row_values(&row.raw);
149
150 for (col_idx, value) in values.iter().enumerate() {
151 if !first {
152 result.push(b'\t');
153 }
154 first = false;
155
156 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
157 let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
158 result.extend_from_slice(&redacted);
159
160 if was_redacted {
161 columns_redacted += 1;
162 row_had_redaction = true;
163 }
164 }
165
166 result.push(b'\n');
167 if row_had_redaction {
168 rows_redacted += 1;
169 }
170 }
171
172 result.extend_from_slice(b"\\.\n");
174
175 Ok((result, rows_redacted, columns_redacted))
176 }
177
178 pub fn rewrite_copy_data(
180 &mut self,
181 data_block: &[u8],
182 table: &TableSchema,
183 strategies: &[StrategyKind],
184 columns: &[String],
185 ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
186 let mut parser = CopyParser::new(data_block)
188 .with_schema(table)
189 .with_column_order(columns.to_vec());
190 let rows = parser.parse_rows()?;
191
192 if rows.is_empty() {
193 return Ok((data_block.to_vec(), 0, 0));
194 }
195
196 let mut result = Vec::with_capacity(data_block.len());
198
199 let mut rows_redacted = 0u64;
200 let mut columns_redacted = 0u64;
201
202 for row in &rows {
203 let mut row_had_redaction = false;
204 let mut first = true;
205
206 let values = self.parse_copy_row_values(&row.raw);
208
209 for (col_idx, value) in values.iter().enumerate() {
210 if !first {
211 result.push(b'\t');
212 }
213 first = false;
214
215 let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
216 let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
217 result.extend_from_slice(&redacted);
218
219 if was_redacted {
220 columns_redacted += 1;
221 row_had_redaction = true;
222 }
223 }
224
225 result.push(b'\n');
226 if row_had_redaction {
227 rows_redacted += 1;
228 }
229 }
230
231 result.extend_from_slice(b"\\.\n");
233
234 Ok((result, rows_redacted, columns_redacted))
235 }
236
237 fn parse_copy_row_values(&self, raw: &[u8]) -> Vec<CopyValueRef> {
239 let mut values = Vec::new();
240 let mut start = 0;
241
242 for (i, &b) in raw.iter().enumerate() {
243 if b == b'\t' {
244 values.push(self.parse_single_copy_value(&raw[start..i]));
245 start = i + 1;
246 }
247 }
248 if start <= raw.len() {
250 values.push(self.parse_single_copy_value(&raw[start..]));
251 }
252
253 values
254 }
255
256 fn parse_single_copy_value(&self, raw: &[u8]) -> CopyValueRef {
258 if raw == b"\\N" {
259 CopyValueRef::Null
260 } else {
261 CopyValueRef::Text(raw.to_vec())
262 }
263 }
264
265 fn redact_copy_value(&mut self, value: &CopyValueRef, strategy: &StrategyKind) -> (Vec<u8>, bool) {
267 if matches!(strategy, StrategyKind::Skip) {
268 let bytes = match value {
269 CopyValueRef::Null => b"\\N".to_vec(),
270 CopyValueRef::Text(t) => t.clone(),
271 };
272 return (bytes, false);
273 }
274
275 let redact_value = match value {
277 CopyValueRef::Null => RedactValue::Null,
278 CopyValueRef::Text(t) => {
279 let decoded = self.decode_copy_escapes(t);
281 RedactValue::String(String::from_utf8_lossy(&decoded).into_owned())
282 }
283 };
284
285 let result = self.apply_strategy(&redact_value, strategy);
287
288 let bytes = match result {
290 RedactValue::Null => b"\\N".to_vec(),
291 RedactValue::String(s) => self.encode_copy_escapes(&s),
292 RedactValue::Integer(i) => i.to_string().into_bytes(),
293 RedactValue::Bytes(b) => self.encode_copy_escapes(&String::from_utf8_lossy(&b)),
294 };
295
296 (bytes, true)
297 }
298
299 fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
301 let mut result = Vec::with_capacity(value.len());
302 let mut i = 0;
303
304 while i < value.len() {
305 if value[i] == b'\\' && i + 1 < value.len() {
306 let next = value[i + 1];
307 let decoded = match next {
308 b'n' => b'\n',
309 b'r' => b'\r',
310 b't' => b'\t',
311 b'\\' => b'\\',
312 _ => {
313 result.push(b'\\');
314 result.push(next);
315 i += 2;
316 continue;
317 }
318 };
319 result.push(decoded);
320 i += 2;
321 } else {
322 result.push(value[i]);
323 i += 1;
324 }
325 }
326
327 result
328 }
329
330 fn encode_copy_escapes(&self, value: &str) -> Vec<u8> {
332 let mut result = Vec::with_capacity(value.len());
333
334 for b in value.bytes() {
335 match b {
336 b'\n' => result.extend_from_slice(b"\\n"),
337 b'\r' => result.extend_from_slice(b"\\r"),
338 b'\t' => result.extend_from_slice(b"\\t"),
339 b'\\' => result.extend_from_slice(b"\\\\"),
340 _ => result.push(b),
341 }
342 }
343
344 result
345 }
346
347 fn extract_column_list(&self, stmt: &str) -> Option<Vec<String>> {
349 let upper = stmt.to_uppercase();
350 let values_pos = upper.find("VALUES")?;
351 let before_values = &stmt[..values_pos];
352
353 let close_paren = before_values.rfind(')')?;
355 let open_paren = before_values[..close_paren].rfind('(')?;
356
357 let col_list = &before_values[open_paren + 1..close_paren];
358
359 let upper_cols = col_list.to_uppercase();
361 if col_list.trim().is_empty()
362 || upper_cols.contains("SELECT")
363 || upper_cols.contains("VALUES")
364 {
365 return None;
366 }
367
368 let columns: Vec<String> = col_list
369 .split(',')
370 .map(|c| {
371 c.trim()
372 .trim_matches('`')
373 .trim_matches('"')
374 .trim_matches('[')
375 .trim_matches(']')
376 .to_string()
377 })
378 .collect();
379
380 if columns.is_empty() {
381 None
382 } else {
383 Some(columns)
384 }
385 }
386
387 fn build_insert_header(&self, table_name: &str, columns: &Option<Vec<String>>) -> Vec<u8> {
389 let mut result = Vec::new();
390
391 result.extend_from_slice(b"INSERT INTO ");
393 result.extend_from_slice(self.quote_identifier(table_name).as_bytes());
394
395 if let Some(cols) = columns {
397 result.extend_from_slice(b" (");
398 for (i, col) in cols.iter().enumerate() {
399 if i > 0 {
400 result.extend_from_slice(b", ");
401 }
402 result.extend_from_slice(self.quote_identifier(col).as_bytes());
403 }
404 result.extend_from_slice(b")");
405 }
406
407 result.extend_from_slice(b" VALUES");
408 result
409 }
410
411 fn quote_identifier(&self, name: &str) -> String {
413 match self.dialect {
414 SqlDialect::MySql => format!("`{}`", name),
415 SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", name),
416 SqlDialect::Mssql => format!("[{}]", name),
417 }
418 }
419
420 fn redact_value(
422 &mut self,
423 value: &ParsedValue,
424 strategy: &StrategyKind,
425 has_strategy: bool,
426 ) -> (String, bool) {
427 if !has_strategy || matches!(strategy, StrategyKind::Skip) {
429 return (self.format_value(value), false);
430 }
431
432 let redact_value = self.parsed_to_redact(value);
434
435 let result = self.apply_strategy(&redact_value, strategy);
437
438 (self.format_redact_value(&result), true)
440 }
441
442 fn parsed_to_redact(&self, value: &ParsedValue) -> RedactValue {
444 match value {
445 ParsedValue::Null => RedactValue::Null,
446 ParsedValue::Integer(n) => RedactValue::Integer(*n),
447 ParsedValue::BigInteger(n) => RedactValue::Integer(*n as i64), ParsedValue::String { value } => RedactValue::String(value.clone()),
449 ParsedValue::Hex(bytes) => RedactValue::Bytes(bytes.clone()),
450 ParsedValue::Other(bytes) => {
451 RedactValue::String(String::from_utf8_lossy(bytes).into_owned())
452 }
453 }
454 }
455
456 fn apply_strategy(&mut self, value: &RedactValue, strategy: &StrategyKind) -> RedactValue {
458 match strategy {
459 StrategyKind::Null => NullStrategy::new().apply(value, &mut self.rng),
460 StrategyKind::Constant { value: constant } => {
461 ConstantStrategy::new(constant.clone()).apply(value, &mut self.rng)
462 }
463 StrategyKind::Hash { preserve_domain } => {
464 HashStrategy::new(*preserve_domain).apply(value, &mut self.rng)
465 }
466 StrategyKind::Mask { pattern } => {
467 MaskStrategy::new(pattern.clone()).apply(value, &mut self.rng)
468 }
469 StrategyKind::Fake { generator } => {
470 FakeStrategy::new(generator.clone(), self.locale.clone()).apply(value, &mut self.rng)
471 }
472 StrategyKind::Shuffle => {
473 value.clone()
476 }
477 StrategyKind::Skip => value.clone(),
478 }
479 }
480
481 fn format_value(&self, value: &ParsedValue) -> String {
483 match value {
484 ParsedValue::Null => "NULL".to_string(),
485 ParsedValue::Integer(n) => n.to_string(),
486 ParsedValue::BigInteger(n) => n.to_string(),
487 ParsedValue::String { value } => self.format_sql_string(value),
488 ParsedValue::Hex(bytes) => String::from_utf8_lossy(bytes).into_owned(),
489 ParsedValue::Other(bytes) => String::from_utf8_lossy(bytes).into_owned(),
490 }
491 }
492
493 fn format_redact_value(&self, value: &RedactValue) -> String {
495 match value {
496 RedactValue::Null => "NULL".to_string(),
497 RedactValue::Integer(n) => n.to_string(),
498 RedactValue::String(s) => self.format_sql_string(s),
499 RedactValue::Bytes(b) => {
500 format!("0x{}", hex::encode(b))
502 }
503 }
504 }
505
506 fn format_sql_string(&self, value: &str) -> String {
508 match self.dialect {
509 SqlDialect::MySql => {
510 let escaped = value
512 .replace('\\', "\\\\")
513 .replace('\'', "\\'")
514 .replace('\n', "\\n")
515 .replace('\r', "\\r")
516 .replace('\t', "\\t")
517 .replace('\0', "\\0");
518 format!("'{}'", escaped)
519 }
520 SqlDialect::Postgres | SqlDialect::Sqlite => {
521 let escaped = value.replace('\'', "''");
523 format!("'{}'", escaped)
524 }
525 SqlDialect::Mssql => {
526 let escaped = value.replace('\'', "''");
528 if value.bytes().any(|b| b > 127) {
530 format!("N'{}'", escaped)
531 } else {
532 format!("'{}'", escaped)
533 }
534 }
535 }
536 }
537}
538
/// One field of a `COPY` data row as read from the input.
enum CopyValueRef {
    /// The field was exactly `\N`.
    Null,
    /// Any other field; holds the raw bytes with escapes still encoded.
    Text(Vec<u8>),
}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId, TableSchema};

    /// Builds one column of the fixture schema.
    fn column(
        name: &str,
        col_type: ColumnType,
        ordinal: ColumnId,
        is_primary_key: bool,
        is_nullable: bool,
    ) -> Column {
        Column {
            name: name.to_string(),
            col_type,
            ordinal,
            is_primary_key,
            is_nullable,
        }
    }

    /// Three-column `users` table: `id` (primary key), `email`, `name`.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            column("id", ColumnType::Int, ColumnId(0), true, false),
            column("email", ColumnType::Text, ColumnId(1), false, false),
            column("name", ColumnType::Text, ColumnId(2), false, true),
        ];

        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns,
            primary_key: vec![ColumnId(0)],
            foreign_keys: Vec::new(),
            indexes: Vec::new(),
            create_statement: None,
        }
    }

    /// Rewriter with a fixed seed so results are reproducible.
    fn seeded(dialect: SqlDialect) -> ValueRewriter {
        ValueRewriter::new(Some(42), dialect, "en".to_string())
    }

    #[test]
    fn test_rewrite_insert_mysql() {
        let mut rewriter = seeded(SqlDialect::MySql);
        let schema = create_test_schema();

        let stmt = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'alice@example.com', 'Alice');";
        let strategies = vec![
            StrategyKind::Skip,
            StrategyKind::Hash {
                preserve_domain: true,
            },
            StrategyKind::Fake {
                generator: "name".to_string(),
            },
        ];

        let (bytes, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("INSERT INTO `users`"));
        assert!(text.contains("VALUES"));
        assert_eq!(rows, 1);
        // `id` is skipped; `email` and `name` are redacted.
        assert_eq!(cols, 2);
    }

    #[test]
    fn test_rewrite_insert_mssql() {
        let mut rewriter = seeded(SqlDialect::Mssql);
        let schema = create_test_schema();

        let stmt = b"INSERT INTO [users] ([id], [email], [name]) VALUES (1, N'alice@example.com', N'Alice');";
        let strategies = vec![StrategyKind::Skip, StrategyKind::Null, StrategyKind::Skip];

        let (bytes, rows, cols) = rewriter
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let text = String::from_utf8_lossy(&bytes);

        assert!(text.contains("INSERT INTO [users]"));
        assert!(text.contains("NULL"));
        assert_eq!(rows, 1);
        assert_eq!(cols, 1);
    }

    #[test]
    fn test_format_sql_string_mysql() {
        let rewriter = seeded(SqlDialect::MySql);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it\\'s'");
        assert_eq!(rewriter.format_sql_string("line\nbreak"), "'line\\nbreak'");
    }

    #[test]
    fn test_format_sql_string_postgres() {
        let rewriter = seeded(SqlDialect::Postgres);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it''s'");
    }

    #[test]
    fn test_format_sql_string_mssql() {
        let rewriter = seeded(SqlDialect::Mssql);
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("café"), "N'café'");
    }

    #[test]
    fn test_quote_identifier() {
        let unseeded = |dialect| ValueRewriter::new(None, dialect, "en".to_string());

        let mysql = unseeded(SqlDialect::MySql);
        assert_eq!(mysql.quote_identifier("users"), "`users`");

        let pg = unseeded(SqlDialect::Postgres);
        assert_eq!(pg.quote_identifier("users"), "\"users\"");

        let mssql = unseeded(SqlDialect::Mssql);
        assert_eq!(mssql.quote_identifier("users"), "[users]");
    }
}