Skip to main content

sql_splitter/redactor/
rewriter.rs

1//! Value rewriter for INSERT and COPY statement redaction.
2//!
3//! Handles parsing values, applying redaction strategies, and formatting
4//! the redacted values back to SQL with proper dialect-aware escaping.
5
6use crate::parser::mysql_insert::{InsertParser, ParsedValue};
7use crate::parser::postgres_copy::{parse_copy_columns, CopyParser};
8use crate::parser::SqlDialect;
9use crate::redactor::strategy::{
10    ConstantStrategy, FakeStrategy, HashStrategy, MaskStrategy, NullStrategy, RedactValue,
11    Strategy, StrategyKind,
12};
13use crate::schema::TableSchema;
14use rand::rngs::StdRng;
15use rand::SeedableRng;
16
17/// Rewriter for INSERT and COPY statements
18pub struct ValueRewriter {
19    /// RNG for reproducible redaction
20    rng: StdRng,
21    /// Dialect for output formatting
22    dialect: SqlDialect,
23    /// Locale for fake data generation
24    locale: String,
25}
26
27impl ValueRewriter {
28    /// Create a new rewriter with optional seed for reproducibility
29    pub fn new(seed: Option<u64>, dialect: SqlDialect, locale: String) -> Self {
30        let rng = match seed {
31            Some(s) => StdRng::seed_from_u64(s),
32            None => StdRng::from_os_rng(),
33        };
34        Self {
35            rng,
36            dialect,
37            locale,
38        }
39    }
40
41    /// Rewrite an INSERT statement with redacted values
42    pub fn rewrite_insert(
43        &mut self,
44        stmt: &[u8],
45        table_name: &str,
46        table: &TableSchema,
47        strategies: &[StrategyKind],
48    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
49        // Parse the INSERT statement
50        let mut parser = InsertParser::new(stmt).with_schema(table);
51        let rows = parser.parse_rows()?;
52
53        if rows.is_empty() {
54            return Ok((stmt.to_vec(), 0, 0));
55        }
56
57        // Get the column list (if any) from the statement
58        let stmt_str = String::from_utf8_lossy(stmt);
59        let column_list = self.extract_column_list(&stmt_str);
60
61        // Build the header: INSERT INTO table_name (columns) VALUES
62        let mut result = self.build_insert_header(table_name, &column_list);
63
64        let mut rows_redacted = 0u64;
65        let mut columns_redacted = 0u64;
66        let num_strategies = strategies.len();
67
68        for (row_idx, row) in rows.iter().enumerate() {
69            if row_idx > 0 {
70                result.extend_from_slice(b",");
71            }
72            result.extend_from_slice(b"\n(");
73
74            let mut row_had_redaction = false;
75
76            for (col_idx, value) in row.values.iter().enumerate() {
77                if col_idx > 0 {
78                    result.extend_from_slice(b", ");
79                }
80
81                // Get strategy for this column (may be Skip if index out of bounds)
82                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
83
84                // Apply redaction
85                let (redacted_sql, was_redacted) =
86                    self.redact_value(value, strategy, col_idx < num_strategies);
87                result.extend_from_slice(redacted_sql.as_bytes());
88
89                if was_redacted {
90                    columns_redacted += 1;
91                    row_had_redaction = true;
92                }
93            }
94
95            result.extend_from_slice(b")");
96            if row_had_redaction {
97                rows_redacted += 1;
98            }
99        }
100
101        result.extend_from_slice(b";\n");
102
103        Ok((result, rows_redacted, columns_redacted))
104    }
105
106    /// Rewrite a COPY statement with redacted values (PostgreSQL)
107    pub fn rewrite_copy(
108        &mut self,
109        stmt: &[u8],
110        _table_name: &str,
111        table: &TableSchema,
112        strategies: &[StrategyKind],
113    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
114        // COPY statements include the header and data block
115        // Format: COPY table (cols) FROM stdin;\ndata\n\.\n
116
117        let stmt_str = String::from_utf8_lossy(stmt);
118
119        // Find the header line (ends with "FROM stdin;" or similar)
120        let header_end = stmt_str
121            .find('\n')
122            .ok_or_else(|| anyhow::anyhow!("Invalid COPY statement: no newline"))?;
123        let header = &stmt_str[..header_end];
124        let data_block = &stmt[header_end + 1..];
125
126        // Parse column list from header
127        let columns = parse_copy_columns(header);
128
129        // Parse data rows
130        let mut parser = CopyParser::new(data_block)
131            .with_schema(table)
132            .with_column_order(columns.clone());
133        let rows = parser.parse_rows()?;
134
135        if rows.is_empty() {
136            return Ok((stmt.to_vec(), 0, 0));
137        }
138
139        // Build result: header + redacted data + terminator
140        let mut result = Vec::with_capacity(stmt.len());
141        result.extend_from_slice(header.as_bytes());
142        result.push(b'\n');
143
144        let mut rows_redacted = 0u64;
145        let mut columns_redacted = 0u64;
146
147        for row in &rows {
148            let mut row_had_redaction = false;
149            let mut first = true;
150
151            // Parse the raw values from the row
152            let values = self.parse_copy_row_values(&row.raw);
153
154            for (col_idx, value) in values.iter().enumerate() {
155                if !first {
156                    result.push(b'\t');
157                }
158                first = false;
159
160                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
161                let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
162                result.extend_from_slice(&redacted);
163
164                if was_redacted {
165                    columns_redacted += 1;
166                    row_had_redaction = true;
167                }
168            }
169
170            result.push(b'\n');
171            if row_had_redaction {
172                rows_redacted += 1;
173            }
174        }
175
176        // Add terminator
177        result.extend_from_slice(b"\\.\n");
178
179        Ok((result, rows_redacted, columns_redacted))
180    }
181
182    /// Rewrite just the COPY data block (header handled separately)
183    pub fn rewrite_copy_data(
184        &mut self,
185        data_block: &[u8],
186        table: &TableSchema,
187        strategies: &[StrategyKind],
188        columns: &[String],
189    ) -> anyhow::Result<(Vec<u8>, u64, u64)> {
190        // Parse data rows
191        let mut parser = CopyParser::new(data_block)
192            .with_schema(table)
193            .with_column_order(columns.to_vec());
194        let rows = parser.parse_rows()?;
195
196        if rows.is_empty() {
197            return Ok((data_block.to_vec(), 0, 0));
198        }
199
200        // Build result: redacted data + terminator
201        let mut result = Vec::with_capacity(data_block.len());
202
203        let mut rows_redacted = 0u64;
204        let mut columns_redacted = 0u64;
205
206        for row in &rows {
207            let mut row_had_redaction = false;
208            let mut first = true;
209
210            // Parse the raw values from the row
211            let values = self.parse_copy_row_values(&row.raw);
212
213            for (col_idx, value) in values.iter().enumerate() {
214                if !first {
215                    result.push(b'\t');
216                }
217                first = false;
218
219                let strategy = strategies.get(col_idx).unwrap_or(&StrategyKind::Skip);
220                let (redacted, was_redacted) = self.redact_copy_value(value, strategy);
221                result.extend_from_slice(&redacted);
222
223                if was_redacted {
224                    columns_redacted += 1;
225                    row_had_redaction = true;
226                }
227            }
228
229            result.push(b'\n');
230            if row_had_redaction {
231                rows_redacted += 1;
232            }
233        }
234
235        // Add terminator
236        result.extend_from_slice(b"\\.\n");
237
238        Ok((result, rows_redacted, columns_redacted))
239    }
240
241    /// Parse tab-separated values from a COPY row
242    fn parse_copy_row_values(&self, raw: &[u8]) -> Vec<CopyValueRef> {
243        let mut values = Vec::new();
244        let mut start = 0;
245
246        for (i, &b) in raw.iter().enumerate() {
247            if b == b'\t' {
248                values.push(self.parse_single_copy_value(&raw[start..i]));
249                start = i + 1;
250            }
251        }
252        // Last value
253        if start <= raw.len() {
254            values.push(self.parse_single_copy_value(&raw[start..]));
255        }
256
257        values
258    }
259
260    /// Parse a single COPY value
261    fn parse_single_copy_value(&self, raw: &[u8]) -> CopyValueRef {
262        if raw == b"\\N" {
263            CopyValueRef::Null
264        } else {
265            CopyValueRef::Text(raw.to_vec())
266        }
267    }
268
269    /// Redact a COPY value and return the redacted bytes
270    fn redact_copy_value(
271        &mut self,
272        value: &CopyValueRef,
273        strategy: &StrategyKind,
274    ) -> (Vec<u8>, bool) {
275        if matches!(strategy, StrategyKind::Skip) {
276            let bytes = match value {
277                CopyValueRef::Null => b"\\N".to_vec(),
278                CopyValueRef::Text(t) => t.clone(),
279            };
280            return (bytes, false);
281        }
282
283        // Convert to RedactValue
284        let redact_value = match value {
285            CopyValueRef::Null => RedactValue::Null,
286            CopyValueRef::Text(t) => {
287                // Decode escape sequences first
288                let decoded = self.decode_copy_escapes(t);
289                RedactValue::String(String::from_utf8_lossy(&decoded).into_owned())
290            }
291        };
292
293        // Apply strategy
294        let result = self.apply_strategy(&redact_value, strategy);
295
296        // Convert back to COPY format
297        let bytes = match result {
298            RedactValue::Null => b"\\N".to_vec(),
299            RedactValue::String(s) => self.encode_copy_escapes(&s),
300            RedactValue::Integer(i) => i.to_string().into_bytes(),
301            RedactValue::Bytes(b) => self.encode_copy_escapes(&String::from_utf8_lossy(&b)),
302        };
303
304        (bytes, true)
305    }
306
307    /// Decode PostgreSQL COPY escape sequences
308    fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
309        let mut result = Vec::with_capacity(value.len());
310        let mut i = 0;
311
312        while i < value.len() {
313            if value[i] == b'\\' && i + 1 < value.len() {
314                let next = value[i + 1];
315                let decoded = match next {
316                    b'n' => b'\n',
317                    b'r' => b'\r',
318                    b't' => b'\t',
319                    b'\\' => b'\\',
320                    _ => {
321                        result.push(b'\\');
322                        result.push(next);
323                        i += 2;
324                        continue;
325                    }
326                };
327                result.push(decoded);
328                i += 2;
329            } else {
330                result.push(value[i]);
331                i += 1;
332            }
333        }
334
335        result
336    }
337
338    /// Encode string for COPY format (escape special characters)
339    fn encode_copy_escapes(&self, value: &str) -> Vec<u8> {
340        let mut result = Vec::with_capacity(value.len());
341
342        for b in value.bytes() {
343            match b {
344                b'\n' => result.extend_from_slice(b"\\n"),
345                b'\r' => result.extend_from_slice(b"\\r"),
346                b'\t' => result.extend_from_slice(b"\\t"),
347                b'\\' => result.extend_from_slice(b"\\\\"),
348                _ => result.push(b),
349            }
350        }
351
352        result
353    }
354
355    /// Extract column list from INSERT statement
356    fn extract_column_list(&self, stmt: &str) -> Option<Vec<String>> {
357        let upper = stmt.to_uppercase();
358        let values_pos = upper.find("VALUES")?;
359        let before_values = &stmt[..values_pos];
360
361        // Find the last (...) before VALUES
362        let close_paren = before_values.rfind(')')?;
363        let open_paren = before_values[..close_paren].rfind('(')?;
364
365        let col_list = &before_values[open_paren + 1..close_paren];
366
367        // Check if this looks like a column list
368        let upper_cols = col_list.to_uppercase();
369        if col_list.trim().is_empty()
370            || upper_cols.contains("SELECT")
371            || upper_cols.contains("VALUES")
372        {
373            return None;
374        }
375
376        let columns: Vec<String> = col_list
377            .split(',')
378            .map(|c| {
379                c.trim()
380                    .trim_matches('`')
381                    .trim_matches('"')
382                    .trim_matches('[')
383                    .trim_matches(']')
384                    .to_string()
385            })
386            .collect();
387
388        if columns.is_empty() {
389            None
390        } else {
391            Some(columns)
392        }
393    }
394
395    /// Build INSERT statement header
396    fn build_insert_header(&self, table_name: &str, columns: &Option<Vec<String>>) -> Vec<u8> {
397        let mut result = Vec::new();
398
399        // INSERT INTO table_name
400        result.extend_from_slice(b"INSERT INTO ");
401        result.extend_from_slice(self.quote_identifier(table_name).as_bytes());
402
403        // Optional column list
404        if let Some(cols) = columns {
405            result.extend_from_slice(b" (");
406            for (i, col) in cols.iter().enumerate() {
407                if i > 0 {
408                    result.extend_from_slice(b", ");
409                }
410                result.extend_from_slice(self.quote_identifier(col).as_bytes());
411            }
412            result.extend_from_slice(b")");
413        }
414
415        result.extend_from_slice(b" VALUES");
416        result
417    }
418
419    /// Quote an identifier based on dialect
420    fn quote_identifier(&self, name: &str) -> String {
421        match self.dialect {
422            SqlDialect::MySql => format!("`{}`", name),
423            SqlDialect::Postgres | SqlDialect::Sqlite => format!("\"{}\"", name),
424            SqlDialect::Mssql => format!("[{}]", name),
425        }
426    }
427
428    /// Redact a parsed value and format it for SQL output
429    fn redact_value(
430        &mut self,
431        value: &ParsedValue,
432        strategy: &StrategyKind,
433        has_strategy: bool,
434    ) -> (String, bool) {
435        // Skip strategy means no redaction
436        if !has_strategy || matches!(strategy, StrategyKind::Skip) {
437            return (self.format_value(value), false);
438        }
439
440        // Convert ParsedValue to RedactValue
441        let redact_value = self.parsed_to_redact(value);
442
443        // Apply the strategy
444        let result = self.apply_strategy(&redact_value, strategy);
445
446        // Format the result for SQL
447        (self.format_redact_value(&result), true)
448    }
449
450    /// Convert ParsedValue to RedactValue
451    fn parsed_to_redact(&self, value: &ParsedValue) -> RedactValue {
452        match value {
453            ParsedValue::Null => RedactValue::Null,
454            ParsedValue::Integer(n) => RedactValue::Integer(*n),
455            ParsedValue::BigInteger(n) => RedactValue::Integer(*n as i64), // Potential truncation
456            ParsedValue::String { value } => RedactValue::String(value.clone()),
457            ParsedValue::Hex(bytes) => RedactValue::Bytes(bytes.clone()),
458            ParsedValue::Other(bytes) => {
459                RedactValue::String(String::from_utf8_lossy(bytes).into_owned())
460            }
461        }
462    }
463
464    /// Apply a redaction strategy to a value
465    fn apply_strategy(&mut self, value: &RedactValue, strategy: &StrategyKind) -> RedactValue {
466        match strategy {
467            StrategyKind::Null => NullStrategy::new().apply(value, &mut self.rng),
468            StrategyKind::Constant { value: constant } => {
469                ConstantStrategy::new(constant.clone()).apply(value, &mut self.rng)
470            }
471            StrategyKind::Hash { preserve_domain } => {
472                HashStrategy::new(*preserve_domain).apply(value, &mut self.rng)
473            }
474            StrategyKind::Mask { pattern } => {
475                MaskStrategy::new(pattern.clone()).apply(value, &mut self.rng)
476            }
477            StrategyKind::Fake { generator } => {
478                FakeStrategy::new(generator.clone(), self.locale.clone())
479                    .apply(value, &mut self.rng)
480            }
481            StrategyKind::Shuffle => {
482                // Shuffle is special - needs column-level state
483                // For now, treat as skip (shuffle implemented at higher level)
484                value.clone()
485            }
486            StrategyKind::Skip => value.clone(),
487        }
488    }
489
490    /// Format a ParsedValue for SQL output
491    fn format_value(&self, value: &ParsedValue) -> String {
492        match value {
493            ParsedValue::Null => "NULL".to_string(),
494            ParsedValue::Integer(n) => n.to_string(),
495            ParsedValue::BigInteger(n) => n.to_string(),
496            ParsedValue::String { value } => self.format_sql_string(value),
497            ParsedValue::Hex(bytes) => String::from_utf8_lossy(bytes).into_owned(),
498            ParsedValue::Other(bytes) => String::from_utf8_lossy(bytes).into_owned(),
499        }
500    }
501
502    /// Format a RedactValue for SQL output
503    fn format_redact_value(&self, value: &RedactValue) -> String {
504        match value {
505            RedactValue::Null => "NULL".to_string(),
506            RedactValue::Integer(n) => n.to_string(),
507            RedactValue::String(s) => self.format_sql_string(s),
508            RedactValue::Bytes(b) => {
509                // Format as hex literal
510                format!("0x{}", hex::encode(b))
511            }
512        }
513    }
514
515    /// Format a string for SQL with proper escaping based on dialect
516    fn format_sql_string(&self, value: &str) -> String {
517        match self.dialect {
518            SqlDialect::MySql => {
519                // MySQL uses backslash escaping
520                let escaped = value
521                    .replace('\\', "\\\\")
522                    .replace('\'', "\\'")
523                    .replace('\n', "\\n")
524                    .replace('\r', "\\r")
525                    .replace('\t', "\\t")
526                    .replace('\0', "\\0");
527                format!("'{}'", escaped)
528            }
529            SqlDialect::Postgres | SqlDialect::Sqlite => {
530                // PostgreSQL/SQLite use doubled single quotes
531                let escaped = value.replace('\'', "''");
532                format!("'{}'", escaped)
533            }
534            SqlDialect::Mssql => {
535                // MSSQL uses N'...' for Unicode strings with doubled quotes
536                let escaped = value.replace('\'', "''");
537                // Use N'...' for non-ASCII or always for safety
538                if value.bytes().any(|b| b > 127) {
539                    format!("N'{}'", escaped)
540                } else {
541                    format!("'{}'", escaped)
542                }
543            }
544        }
545    }
546}
547
/// One field of a PostgreSQL `COPY` text-format row, as split out by the rewriter.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CopyValueRef {
    /// The `\N` null marker.
    Null,
    /// Any other field, still in its raw (backslash-escaped) COPY form.
    Text(Vec<u8>),
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId, TableSchema};

    /// Build one fixture column.
    fn column(
        name: &str,
        col_type: ColumnType,
        ordinal: ColumnId,
        is_primary_key: bool,
        is_nullable: bool,
    ) -> Column {
        Column {
            name: name.to_string(),
            col_type,
            ordinal,
            is_primary_key,
            is_nullable,
        }
    }

    /// `users(id INT PK, email TEXT NOT NULL, name TEXT NULL)`.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            column("id", ColumnType::Int, ColumnId(0), true, false),
            column("email", ColumnType::Text, ColumnId(1), false, false),
            column("name", ColumnType::Text, ColumnId(2), false, true),
        ];

        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns,
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    /// Rewriter for `dialect` with the English locale.
    fn rewriter(dialect: SqlDialect, seed: Option<u64>) -> ValueRewriter {
        ValueRewriter::new(seed, dialect, "en".to_string())
    }

    #[test]
    fn test_rewrite_insert_mysql() {
        let schema = create_test_schema();
        let mut rw = rewriter(SqlDialect::MySql, Some(42));

        let stmt: &[u8] = b"INSERT INTO `users` (`id`, `email`, `name`) VALUES (1, 'alice@example.com', 'Alice');";
        let strategies = [
            StrategyKind::Skip,
            StrategyKind::Hash {
                preserve_domain: true,
            },
            StrategyKind::Fake {
                generator: "name".to_string(),
            },
        ];

        let (out, rows, cols) = rw
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let text = String::from_utf8_lossy(&out);

        assert!(text.contains("INSERT INTO `users`"));
        assert!(text.contains("VALUES"));
        // One row touched; email and name redacted, id skipped.
        assert_eq!((rows, cols), (1, 2));
    }

    #[test]
    fn test_rewrite_insert_mssql() {
        let schema = create_test_schema();
        let mut rw = rewriter(SqlDialect::Mssql, Some(42));

        let stmt: &[u8] = b"INSERT INTO [users] ([id], [email], [name]) VALUES (1, N'alice@example.com', N'Alice');";
        let strategies = [StrategyKind::Skip, StrategyKind::Null, StrategyKind::Skip];

        let (out, rows, cols) = rw
            .rewrite_insert(stmt, "users", &schema, &strategies)
            .unwrap();
        let text = String::from_utf8_lossy(&out);

        assert!(text.contains("INSERT INTO [users]"));
        // The email column was replaced by NULL.
        assert!(text.contains("NULL"));
        assert_eq!((rows, cols), (1, 1));
    }

    #[test]
    fn test_format_sql_string_mysql() {
        let rw = rewriter(SqlDialect::MySql, Some(42));
        let cases = [
            ("hello", "'hello'"),
            ("it's", "'it\\'s'"),
            ("line\nbreak", "'line\\nbreak'"),
        ];
        for (input, expected) in cases {
            assert_eq!(rw.format_sql_string(input), expected);
        }
    }

    #[test]
    fn test_format_sql_string_postgres() {
        let rw = rewriter(SqlDialect::Postgres, Some(42));
        let cases = [("hello", "'hello'"), ("it's", "'it''s'")];
        for (input, expected) in cases {
            assert_eq!(rw.format_sql_string(input), expected);
        }
    }

    #[test]
    fn test_format_sql_string_mssql() {
        let rw = rewriter(SqlDialect::Mssql, Some(42));
        let cases = [("hello", "'hello'"), ("café", "N'café'")];
        for (input, expected) in cases {
            assert_eq!(rw.format_sql_string(input), expected);
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            (SqlDialect::MySql, "`users`"),
            (SqlDialect::Postgres, "\"users\""),
            (SqlDialect::Mssql, "[users]"),
        ];
        for (dialect, expected) in cases {
            assert_eq!(rewriter(dialect, None).quote_identifier("users"), expected);
        }
    }
}