sql_splitter/redactor/
rewriter.rs

1//! Value rewriter for INSERT and COPY statement redaction.
2//!
3//! Handles parsing values, applying redaction strategies, and formatting
4//! the redacted values back to SQL with proper dialect-aware escaping.
5
6use crate::parser::mysql_insert::{InsertParser, ParsedValue};
7use crate::parser::postgres_copy::{parse_copy_columns, CopyParser};
8use crate::parser::SqlDialect;
9use crate::redactor::strategy::{
10    ConstantStrategy, FakeStrategy, HashStrategy, MaskStrategy, NullStrategy, RedactValue,
11    Strategy, StrategyKind,
12};
13use crate::schema::TableSchema;
14use rand::rngs::StdRng;
15use rand::SeedableRng;
16
17/// Rewriter for INSERT and COPY statements
18pub struct ValueRewriter {
19    /// RNG for reproducible redaction
20    rng: StdRng,
21    /// Dialect for output formatting
22    dialect: SqlDialect,
23    /// Locale for fake data generation
24    locale: String,
25}
26
27impl ValueRewriter {
28    /// Create a new rewriter with optional seed for reproducibility
29    pub fn new(seed: Option<u64>, dialect: SqlDialect, locale: String) -> Self {
30        let rng = match seed {
31            Some(s) => StdRng::seed_from_u64(s),
32            None => StdRng::from_entropy(),
33        };
34        Self { rng, dialect, locale }
35    }
36
37    /// Rewrite an INSERT statement with redacted values
38    pub fn rewrite_insert(
39        &mut self,
40        stmt: &[u8],
41        table_name: &str,
42        table: &TableSchema,
43        strategies: &[StrategyKind],
44    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
45        // Parse the INSERT statement
46        let mut parser = InsertParser::new(stmt).with_schema(table);
47        let rows = parser.parse_rows()?;
48
49        if rows.is_empty() {
50            return Ok((stmt.to_vec(), 0, 0));
51        }
52
53        // Get the column list (if any) from the statement
54        let stmt_str = String::from_utf8_lossy(stmt);
55        let column_list = self.extract_column_list(&stmt_str);
56
57        // Build the header: INSERT INTO table_name (columns) VALUES
58        let mut result = self.build_insert_header(table_name, &column_list);
59
60        let mut rows_redacted = 0u64;
61        let mut columns_redacted = 0u64;
62        let num_strategies = strategies.len();
63
64        for (row_idx, row) in rows.iter().enumerate() {
65            if row_idx > 0 {
66                result.extend_from_slice(b",");
67            }
68            result.extend_from_slice(b"\n(");
69
70            let mut row_had_redaction = false;
71
72            for (col_idx, value) in row.values.iter().enumerate() {
73                if col_idx > 0 {
74                    result.extend_from_slice(b", ");
75                }
76
77                // Get strategy for this column (may be Skip if index out of bounds)
78                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
79
80                // Apply redaction
81                let (redacted_sql, was_redacted) =
82                    self.redact_value(value, strategy, col_idx < num_strategies);
83                result.extend_from_slice(redacted_sql.as_bytes());
84
85                if was_redacted {
86                    columns_redacted += 1;
87                    row_had_redaction = true;
88                }
89            }
90
91            result.extend_from_slice(b")");
92            if row_had_redaction {
93                rows_redacted += 1;
94            }
95        }
96
97        result.extend_from_slice(b";\n");
98
99        Ok((result, rows_redacted, columns_redacted))
100    }
101
102    /// Rewrite a COPY statement with redacted values (PostgreSQL)
103    pub fn rewrite_copy(
104        &mut self,
105        stmt: &[u8],
106        _table_name: &str,
107        table: &TableSchema,
108        strategies: &[StrategyKind],
109    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
110        // COPY statements include the header and data block
111        // Format: COPY table (cols) FROM stdin;\ndata\n\.\n
112
113        let stmt_str = String::from_utf8_lossy(stmt);
114
115        // Find the header line (ends with "FROM stdin;" or similar)
116        let header_end = stmt_str
117            .find('\n')
118            .ok_or_else(|| anyhow::anyhow!("Invalid COPY statement: no newline"))?;
119        let header = &stmt_str[..header_end];
120        let data_block = &stmt[header_end + 1..];
121
122        // Parse column list from header
123        let columns = parse_copy_columns(header);
124
125        // Parse data rows
126        let mut parser = CopyParser::new(data_block)
127            .with_schema(table)
128            .with_column_order(columns.clone());
129        let rows = parser.parse_rows()?;
130
131        if rows.is_empty() {
132            return Ok((stmt.to_vec(), 0, 0));
133        }
134
135        // Build result: header + redacted data + terminator
136        let mut result = Vec::with_capacity(stmt.len());
137        result.extend_from_slice(header.as_bytes());
138        result.push(b'\n');
139
140        let mut rows_redacted = 0u64;
141        let mut columns_redacted = 0u64;
142
143        for row in &rows {
144            let mut row_had_redaction = false;
145            let mut first = true;
146
147            // Parse the raw values from the row
148            let values = self.parse_copy_row_values(&row.raw);
149
150            for (col_idx, value) in values.iter().enumerate() {
151                if !first {
152                    result.push(b'\t');
153                }
154                first = false;
155
156                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
157                let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
158                result.extend_from_slice(&redacted);
159
160                if was_redacted {
161                    columns_redacted += 1;
162                    row_had_redaction = true;
163                }
164            }
165
166            result.push(b'\n');
167            if row_had_redaction {
168                rows_redacted += 1;
169            }
170        }
171
172        // Add terminator
173        result.extend_from_slice(b"\\.\n");
174
175        Ok((result, rows_redacted, columns_redacted))
176    }
177
178    /// Rewrite just the COPY data block (header handled separately)
179    pub fn rewrite_copy_data(
180        &mut self,
181        data_block: &[u8],
182        table: &TableSchema,
183        strategies: &[StrategyKind],
184        columns: &[String],
185    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
186        // Parse data rows
187        let mut parser = CopyParser::new(data_block)
188            .with_schema(table)
189            .with_column_order(columns.to_vec());
190        let rows = parser.parse_rows()?;
191
192        if rows.is_empty() {
193            return Ok((data_block.to_vec(), 0, 0));
194        }
195
196        // Build result: redacted data + terminator
197        let mut result = Vec::with_capacity(data_block.len());
198
199        let mut rows_redacted = 0u64;
200        let mut columns_redacted = 0u64;
201
202        for row in &rows {
203            let mut row_had_redaction = false;
204            let mut first = true;
205
206            // Parse the raw values from the row
207            let values = self.parse_copy_row_values(&row.raw);
208
209            for (col_idx, value) in values.iter().enumerate() {
210                if !first {
211                    result.push(b'\t');
212                }
213                first = false;
214
215                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
216                let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
217                result.extend_from_slice(&redacted);
218
219                if was_redacted {
220                    columns_redacted += 1;
221                    row_had_redaction = true;
222                }
223            }
224
225            result.push(b'\n');
226            if row_had_redaction {
227                rows_redacted += 1;
228            }
229        }
230
231        // Add terminator
232        result.extend_from_slice(b"\\.\n");
233
234        Ok((result, rows_redacted, columns_redacted))
235    }
236
237    /// Parse tab-separated values from a COPY row
238    fn parse_copy_row_values(&self, raw: &[u8]) -> Vec<CopyValueRef> {
239        let mut values = Vec::new();
240        let mut start = 0;
241
242        for (i, &b) in raw.iter().enumerate() {
243            if b == b'\t' {
244                values.push(self.parse_single_copy_value(&raw[start..i]));
245                start = i + 1;
246            }
247        }
248        // Last value
249        if start <= raw.len() {
250            values.push(self.parse_single_copy_value(&raw[start..]));
251        }
252
253        values
254    }
255
256    /// Parse a single COPY value
257    fn parse_single_copy_value(&self, raw: &[u8]) -> CopyValueRef {
258        if raw == b"\\N" {
259            CopyValueRef::Null
260        } else {
261            CopyValueRef::Text(raw.to_vec())
262        }
263    }
264
265    /// Redact a COPY value and return the redacted bytes
266    fn redact_copy_value(&mut self, value: &CopyValueRef, strategy: &StrategyKind) -> (Vec<u8>, bool) {
267        if matches!(strategy, StrategyKind::Skip) {
268            let bytes = match value {
269                CopyValueRef::Null => b"\\N".to_vec(),
270                CopyValueRef::Text(t) => t.clone(),
271            };
272            return (bytes, false);
273        }
274
275        // Convert to RedactValue
276        let redact_value = match value {
277            CopyValueRef::Null => RedactValue::Null,
278            CopyValueRef::Text(t) => {
279                // Decode escape sequences first
280                let decoded = self.decode_copy_escapes(t);
281                RedactValue::String(String::from_utf8_lossy(&decoded).into_owned())
282            }
283        };
284
285        // Apply strategy
286        let result = self.apply_strategy(&redact_value, strategy);
287
288        // Convert back to COPY format
289        let bytes = match result {
290            RedactValue::Null => b"\\N".to_vec(),
291            RedactValue::String(s) => self.encode_copy_escapes(&s),
292            RedactValue::Integer(i) => i.to_string().into_bytes(),
293            RedactValue::Bytes(b) => self.encode_copy_escapes(&String::from_utf8_lossy(&b)),
294        };
295
296        (bytes, true)
297    }
298
299    /// Decode PostgreSQL COPY escape sequences
300    fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
301        let mut result = Vec::with_capacity(value.len());
302        let mut i = 0;
303
304        while i < value.len() {
305            if value[i] == b'\\' && i + 1 < value.len() {
306                let next = value[i + 1];
307                let decoded = match next {
308                    b'n' => b'\n',
309                    b'r' => b'\r',
310                    b't' => b'\t',
311                    b'\\' => b'\\',
312                    _ => {
313                        result.push(b'\\');
314                        result.push(next);
315                        i += 2;
316                        continue;
317                    }
318                };
319                result.push(decoded);
320                i += 2;
321            } else {
322                result.push(value[i]);
323                i += 1;
324            }
325        }
326
327        result
328    }
329
330    /// Encode string for COPY format (escape special characters)
331    fn encode_copy_escapes(&self, value: &str) -> Vec<u8> {
332        let mut result = Vec::with_capacity(value.len());
333
334        for b in value.bytes() {
335            match b {
336                b'\n' => result.extend_from_slice(b"\\n"),
337                b'\r' => result.extend_from_slice(b"\\r"),
338                b'\t' => result.extend_from_slice(b"\\t"),
339                b'\\' => result.extend_from_slice(b"\\\\"),
340                _ => result.push(b),
341            }
342        }
343
344        result
345    }
346
347    /// Extract column list from INSERT statement
348    fn extract_column_list(&self, stmt: &str) -> Option<Vec<String>> {
349        let upper = stmt.to_uppercase();
350        let values_pos = upper.find("VALUES")?;
351        let before_values = &stmt[..values_pos];
352
353        // Find the last (...) before VALUES
354        let close_paren = before_values.rfind(')')?;
355        let open_paren = before_values[..close_paren].rfind('(')?;
356
357        let col_list = &before_values[open_paren + 1..close_paren];
358
359        // Check if this looks like a column list
360        let upper_cols = col_list.to_uppercase();
361        if col_list.trim().is_empty()
362            || upper_cols.contains("SELECT")
363            || upper_cols.contains("VALUES")
364        {
365            return None;
366        }
367
368        let columns: Vec<String> = col_list
369            .split(',')
370            .map(|c| {
371                c.trim()
372                    .trim_matches('`')
373                    .trim_matches('"')
374                    .trim_matches('[')
375                    .trim_matches(']')
376                    .to_string()
377            })
378            .collect();
379
380        if columns.is_empty() {
381            None
382        } else {
383            Some(columns)
384        }
385    }
386
387    /// Build INSERT statement header
388    fn build_insert_header(&self, table_name: &str, columns: &Option<Vec<String>>) -> Vec<u8> {
389        let mut result = Vec::new();
390
391        // INSERT INTO table_name
392        result.extend_from_slice(b"INSERT INTO ");
393        result.extend_from_slice(self.quote_identifier(table_name).as_bytes());
394
395        // Optional column list
396        if let Some(cols) = columns {
397            result.extend_from_slice(b" (");
398            for (i, col) in cols.iter().enumerate() {
399                if i > 0 {
400                    result.extend_from_slice(b", ");
401                }
402                result.extend_from_slice(self.quote_identifier(col).as_bytes());
403            }
404            result.extend_from_slice(b")");
405        }
406
407        result.extend_from_slice(b" VALUES");
408        result
409    }
410
411    /// Quote an identifier based on dialect
412    fn quote_identifier(&self, name: &str) -> String {
413        match self.dialect {
414            SqlDialect::MySql => format!("`{}`", name),
415            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", name),
416            SqlDialect::Mssql => format!("[{}]", name),
417        }
418    }
419
420    /// Redact a parsed value and format it for SQL output
421    fn redact_value(
422        &mut self,
423        value: &ParsedValue,
424        strategy: &StrategyKind,
425        has_strategy: bool,
426    ) -> (String, bool) {
427        // Skip strategy means no redaction
428        if !has_strategy || matches!(strategy, StrategyKind::Skip) {
429            return (self.format_value(value), false);
430        }
431
432        // Convert ParsedValue to RedactValue
433        let redact_value = self.parsed_to_redact(value);
434
435        // Apply the strategy
436        let result = self.apply_strategy(&redact_value, strategy);
437
438        // Format the result for SQL
439        (self.format_redact_value(&result), true)
440    }
441
442    /// Convert ParsedValue to RedactValue
443    fn parsed_to_redact(&self, value: &ParsedValue) -> RedactValue {
444        match value {
445            ParsedValue::Null => RedactValue::Null,
446            ParsedValue::Integer(n) => RedactValue::Integer(*n),
447            ParsedValue::BigInteger(n) => RedactValue::Integer(*n as i64), // Potential truncation
448            ParsedValue::String { value } => RedactValue::String(value.clone()),
449            ParsedValue::Hex(bytes) => RedactValue::Bytes(bytes.clone()),
450            ParsedValue::Other(bytes) => {
451                RedactValue::String(String::from_utf8_lossy(bytes).into_owned())
452            }
453        }
454    }
455
456    /// Apply a redaction strategy to a value
457    fn apply_strategy(&mut self, value: &RedactValue, strategy: &StrategyKind) -> RedactValue {
458        match strategy {
459            StrategyKind::Null => NullStrategy::new().apply(value, &mut self.rng),
460            StrategyKind::Constant { value: constant } => {
461                ConstantStrategy::new(constant.clone()).apply(value, &mut self.rng)
462            }
463            StrategyKind::Hash { preserve_domain } => {
464                HashStrategy::new(*preserve_domain).apply(value, &mut self.rng)
465            }
466            StrategyKind::Mask { pattern } => {
467                MaskStrategy::new(pattern.clone()).apply(value, &mut self.rng)
468            }
469            StrategyKind::Fake { generator } => {
470                FakeStrategy::new(generator.clone(), self.locale.clone()).apply(value, &mut self.rng)
471            }
472            StrategyKind::Shuffle => {
473                // Shuffle is special - needs column-level state
474                // For now, treat as skip (shuffle implemented at higher level)
475                value.clone()
476            }
477            StrategyKind::Skip => value.clone(),
478        }
479    }
480
481    /// Format a ParsedValue for SQL output
482    fn format_value(&self, value: &ParsedValue) -> String {
483        match value {
484            ParsedValue::Null => "NULL".to_string(),
485            ParsedValue::Integer(n) => n.to_string(),
486            ParsedValue::BigInteger(n) => n.to_string(),
487            ParsedValue::String { value } => self.format_sql_string(value),
488            ParsedValue::Hex(bytes) => String::from_utf8_lossy(bytes).into_owned(),
489            ParsedValue::Other(bytes) => String::from_utf8_lossy(bytes).into_owned(),
490        }
491    }
492
493    /// Format a RedactValue for SQL output
494    fn format_redact_value(&self, value: &RedactValue) -> String {
495        match value {
496            RedactValue::Null => "NULL".to_string(),
497            RedactValue::Integer(n) => n.to_string(),
498            RedactValue::String(s) => self.format_sql_string(s),
499            RedactValue::Bytes(b) => {
500                // Format as hex literal
501                format!("0x{}", hex::encode(b))
502            }
503        }
504    }
505
506    /// Format a string for SQL with proper escaping based on dialect
507    fn format_sql_string(&self, value: &str) -> String {
508        match self.dialect {
509            SqlDialect::MySql => {
510                // MySQL uses backslash escaping
511                let escaped = value
512                    .replace('\\', "\\\\")
513                    .replace('\'', "\\'")
514                    .replace('\n', "\\n")
515                    .replace('\r', "\\r")
516                    .replace('\t', "\\t")
517                    .replace('\0', "\\0");
518                format!("'{}'", escaped)
519            }
520            SqlDialect::Postgres | SqlDialect::Sqlite => {
521                // PostgreSQL/SQLite use doubled single quotes
522                let escaped = value.replace('\'', "''");
523                format!("'{}'", escaped)
524            }
525            SqlDialect::Mssql => {
526                // MSSQL uses N'...' for Unicode strings with doubled quotes
527                let escaped = value.replace('\'', "''");
528                // Use N'...' for non-ASCII or always for safety
529                if value.bytes().any(|b| b > 127) {
530                    format!("N'{}'", escaped)
531                } else {
532                    format!("'{}'", escaped)
533                }
534            }
535        }
536    }
537}
538
/// One field of a PostgreSQL COPY text-format row.
///
/// `Text` holds the field's raw bytes exactly as they appeared in the data
/// block (escape sequences not yet decoded); `Null` is the `\N` marker.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CopyValueRef {
    /// The `\N` NULL marker.
    Null,
    /// Raw, still-escaped field bytes.
    Text(Vec<u8>),
}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId, TableSchema};

    fn create_test_schema() -> TableSchema {
        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns: vec![
                Column {
                    name: "id".to_string(),
                    col_type: ColumnType::Int,
                    ordinal: ColumnId(0),
                    is_primary_key: true,
                    is_nullable: false,
                },
                Column {
                    name: "email".to_string(),
                    col_type: ColumnType::Text,
                    ordinal: ColumnId(1),
                    is_primary_key: false,
                    is_nullable: false,
                },
                Column {
                    name: "name".to_string(),
                    col_type: ColumnType::Text,
                    ordinal: ColumnId(2),
                    is_primary_key: false,
                    is_nullable: true,
                },
            ],
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    #[test]
    fn test_rewrite_insert_mysql() {
        let mut rewriter = ValueRewriter::new(Some(42), SqlDialect::MySql, "en".to_string());
        let schema = create_test_schema();

        let stmt = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'alice@example.com', 'Alice');";
        let strategies = vec![
            StrategyKind::Skip,                                   // id
            StrategyKind::Hash { preserve_domain: true },         // email
            StrategyKind::Fake { generator: "name".to_string() }, // name
        ];

        let (result, rows, cols) = rewriter.rewrite_insert(stmt, "users", &schema, &strategies).unwrap();
        let result_str = String::from_utf8_lossy(&result);

        assert!(result_str.contains("INSERT INTO `users`"));
        assert!(result_str.contains("VALUES"));
        assert_eq!(rows, 1);
        assert_eq!(cols, 2); // email and name were redacted
    }

    #[test]
    fn test_rewrite_insert_mssql() {
        let mut rewriter = ValueRewriter::new(Some(42), SqlDialect::Mssql, "en".to_string());
        let schema = create_test_schema();

        let stmt = b"INSERT INTO [users] ([id], [email], [name]) VALUES (1, N'alice@example.com', N'Alice');";
        let strategies = vec![
            StrategyKind::Skip, // id
            StrategyKind::Null, // email
            StrategyKind::Skip, // name
        ];

        let (result, rows, cols) = rewriter.rewrite_insert(stmt, "users", &schema, &strategies).unwrap();
        let result_str = String::from_utf8_lossy(&result);

        assert!(result_str.contains("INSERT INTO [users]"));
        assert!(result_str.contains("NULL")); // email redacted to NULL
        assert_eq!(rows, 1);
        assert_eq!(cols, 1);
    }

    #[test]
    fn test_rewrite_insert_all_skip_counts_nothing() {
        let mut rewriter = ValueRewriter::new(Some(42), SqlDialect::MySql, "en".to_string());
        let schema = create_test_schema();

        let stmt = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'a@example.com', 'A'), (2, 'b@example.com', NULL);";
        let strategies = vec![StrategyKind::Skip, StrategyKind::Skip, StrategyKind::Skip];

        let (result, rows, cols) = rewriter.rewrite_insert(stmt, "users", &schema, &strategies).unwrap();
        let result_str = String::from_utf8_lossy(&result);

        assert!(result_str.contains("(2, 'b@example.com', NULL)"));
        assert_eq!(rows, 0);
        assert_eq!(cols, 0);
    }

    #[test]
    fn test_extract_column_list() {
        let rewriter = ValueRewriter::new(None, SqlDialect::MySql, "en".to_string());
        let expected = Some(vec!["a".to_string(), "b".to_string()]);

        assert_eq!(
            rewriter.extract_column_list("INSERT INTO `t` (`a`, `b`) VALUES (1, 2);"),
            expected
        );
        assert_eq!(
            rewriter.extract_column_list("INSERT INTO [t] ([a], [b]) VALUES (1, 2);"),
            expected
        );
        assert_eq!(rewriter.extract_column_list("INSERT INTO t VALUES (1, 2);"), None);
    }

    #[test]
    fn test_build_insert_header_postgres() {
        let rewriter = ValueRewriter::new(None, SqlDialect::Postgres, "en".to_string());
        let columns = Some(vec!["id".to_string(), "email".to_string()]);

        assert_eq!(
            rewriter.build_insert_header("users", &columns),
            b"INSERT INTO \"users\" (\"id\", \"email\") VALUES".to_vec()
        );
        assert_eq!(
            rewriter.build_insert_header("users", &None),
            b"INSERT INTO \"users\" VALUES".to_vec()
        );
    }

    #[test]
    fn test_copy_escape_round_trip() {
        let rewriter = ValueRewriter::new(None, SqlDialect::Postgres, "en".to_string());

        let encoded = rewriter.encode_copy_escapes("a\tb\nc\\d\re");
        assert_eq!(encoded, b"a\\tb\\nc\\\\d\\re".to_vec());
        assert_eq!(rewriter.decode_copy_escapes(&encoded), b"a\tb\nc\\d\re".to_vec());

        // A lone trailing backslash is preserved.
        assert_eq!(rewriter.decode_copy_escapes(b"x\\"), b"x\\".to_vec());
    }

    #[test]
    fn test_parse_copy_row_values() {
        let rewriter = ValueRewriter::new(None, SqlDialect::Postgres, "en".to_string());

        let values = rewriter.parse_copy_row_values(b"1\t\\N\t");
        assert_eq!(values.len(), 3);
        assert!(matches!(&values[0], CopyValueRef::Text(t) if t.as_slice() == b"1"));
        assert!(matches!(&values[1], CopyValueRef::Null));
        assert!(matches!(&values[2], CopyValueRef::Text(t) if t.is_empty()));
    }

    #[test]
    fn test_format_sql_string_mysql() {
        let rewriter = ValueRewriter::new(Some(42), SqlDialect::MySql, "en".to_string());
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it\\'s'");
        assert_eq!(rewriter.format_sql_string("line\nbreak"), "'line\\nbreak'");
        assert_eq!(rewriter.format_sql_string("a\\b\t\0"), "'a\\\\b\\t\\0'");
    }

    #[test]
    fn test_format_sql_string_postgres() {
        let rewriter = ValueRewriter::new(Some(42), SqlDialect::Postgres, "en".to_string());
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("it's"), "'it''s'");
    }

    #[test]
    fn test_format_sql_string_mssql() {
        let rewriter = ValueRewriter::new(Some(42), SqlDialect::Mssql, "en".to_string());
        assert_eq!(rewriter.format_sql_string("hello"), "'hello'");
        assert_eq!(rewriter.format_sql_string("café"), "N'café'");
    }

    #[test]
    fn test_quote_identifier() {
        let mysql = ValueRewriter::new(None, SqlDialect::MySql, "en".to_string());
        assert_eq!(mysql.quote_identifier("users"), "`users`");

        let pg = ValueRewriter::new(None, SqlDialect::Postgres, "en".to_string());
        assert_eq!(pg.quote_identifier("users"), "\"users\"");

        let sqlite = ValueRewriter::new(None, SqlDialect::Sqlite, "en".to_string());
        assert_eq!(sqlite.quote_identifier("users"), "\"users\"");

        let mssql = ValueRewriter::new(None, SqlDialect::Mssql, "en".to_string());
        assert_eq!(mssql.quote_identifier("users"), "[users]");
    }
}