sql_splitter/redactor/matcher.rs

1//! Column pattern matching for redaction rules.
2
3use crate::redactor::config::{RedactConfig, Rule};
4use crate::redactor::StrategyKind;
5use crate::schema::TableSchema;
6use glob::Pattern;
7
/// Compiled column matcher for efficient pattern matching.
///
/// Built once from a [`RedactConfig`] via [`ColumnMatcher::from_config`], which
/// compiles every rule's glob patterns up front. Lookups walk the rules in
/// configuration order and return the strategy of the first rule that
/// matches; columns matched by no rule get `default_strategy`.
#[derive(Debug)]
pub struct ColumnMatcher {
    /// Compiled rules with glob patterns, in configuration order
    /// (an earlier rule takes precedence over a later one).
    rules: Vec<CompiledRule>,
    /// Default strategy for columns that no rule matches
    default_strategy: StrategyKind,
}
16
/// A rule with pre-compiled glob patterns.
///
/// Produced from a config [`Rule`] whose `column` string is either a bare
/// column glob or `table_glob.column_glob` (split at the first `.`).
#[derive(Debug)]
struct CompiledRule {
    /// Table pattern (`None` = match all tables). `None` is used both for an
    /// unqualified rule and for a rule whose table part is exactly `*`.
    table_pattern: Option<Pattern>,
    /// Column pattern
    column_pattern: Pattern,
    /// The strategy to apply when this rule is the first one to match
    strategy: StrategyKind,
}
27
28impl ColumnMatcher {
29    /// Create a new matcher from configuration
30    pub fn from_config(config: &RedactConfig) -> anyhow::Result<Self> {
31        let mut rules = Vec::with_capacity(config.rules.len());
32
33        for rule in &config.rules {
34            let compiled = Self::compile_rule(rule)?;
35            rules.push(compiled);
36        }
37
38        Ok(Self {
39            rules,
40            default_strategy: config.default_strategy.clone(),
41        })
42    }
43
44    /// Compile a rule into table and column patterns
45    fn compile_rule(rule: &Rule) -> anyhow::Result<CompiledRule> {
46        let pattern = &rule.column;
47
48        // Check if pattern contains a table qualifier (table.column)
49        let (table_pattern, column_pattern) = if let Some(dot_pos) = pattern.find('.') {
50            let table_part = &pattern[..dot_pos];
51            let column_part = &pattern[dot_pos + 1..];
52
53            // Compile table pattern (might be * for all tables)
54            let table_pat = if table_part == "*" {
55                None
56            } else {
57                Some(
58                    Pattern::new(table_part)
59                        .map_err(|e| anyhow::anyhow!("Invalid table pattern '{}': {}", table_part, e))?,
60                )
61            };
62
63            let col_pat = Pattern::new(column_part)
64                .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", column_part, e))?;
65
66            (table_pat, col_pat)
67        } else {
68            // No table qualifier - match all tables
69            let col_pat = Pattern::new(pattern)
70                .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", pattern, e))?;
71            (None, col_pat)
72        };
73
74        Ok(CompiledRule {
75            table_pattern,
76            column_pattern,
77            strategy: rule.strategy.clone(),
78        })
79    }
80
81    /// Get the strategy for a specific column
82    pub fn get_strategy(&self, table_name: &str, column_name: &str) -> StrategyKind {
83        // Find first matching rule (rules are processed in order)
84        for rule in &self.rules {
85            if self.rule_matches(rule, table_name, column_name) {
86                return rule.strategy.clone();
87            }
88        }
89
90        // No match - return default
91        self.default_strategy.clone()
92    }
93
94    /// Get strategies for all columns in a table
95    pub fn get_strategies(&self, table_name: &str, table: &TableSchema) -> Vec<StrategyKind> {
96        table
97            .columns
98            .iter()
99            .map(|col| self.get_strategy(table_name, &col.name))
100            .collect()
101    }
102
103    /// Count how many columns match any redaction rule
104    pub fn count_matches(&self, table_name: &str, table: &TableSchema) -> usize {
105        table
106            .columns
107            .iter()
108            .filter(|col| {
109                let strategy = self.get_strategy(table_name, &col.name);
110                !matches!(strategy, StrategyKind::Skip)
111            })
112            .count()
113    }
114
115    /// Check if a rule matches a table/column pair
116    fn rule_matches(&self, rule: &CompiledRule, table_name: &str, column_name: &str) -> bool {
117        // Check table pattern (if specified)
118        if let Some(ref table_pat) = rule.table_pattern {
119            if !table_pat.matches(table_name) && !table_pat.matches(&table_name.to_lowercase()) {
120                return false;
121            }
122        }
123
124        // Check column pattern
125        rule.column_pattern.matches(column_name)
126            || rule.column_pattern.matches(&column_name.to_lowercase())
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Column, ColumnId, ColumnType, TableId};

    /// Build a config holding `rules`, with every other field at a neutral
    /// value (`Skip` default strategy, no seed, `en` locale, all flags off).
    fn config_with_rules(rules: Vec<Rule>) -> RedactConfig {
        RedactConfig {
            input: std::path::PathBuf::new(),
            output: None,
            dialect: crate::parser::SqlDialect::MySql,
            rules,
            default_strategy: StrategyKind::Skip,
            seed: None,
            locale: "en".to_string(),
            tables_filter: None,
            exclude: vec![],
            strict: false,
            progress: false,
            dry_run: false,
        }
    }

    /// Shorthand for a config [`Rule`] targeting `column` with `strategy`.
    fn make_rule(column: &str, strategy: StrategyKind) -> Rule {
        Rule {
            column: column.to_string(),
            strategy,
        }
    }

    /// Shorthand for a schema [`Column`].
    fn column(
        name: &str,
        col_type: ColumnType,
        ordinal: ColumnId,
        is_primary_key: bool,
        is_nullable: bool,
    ) -> Column {
        Column {
            name: name.to_string(),
            col_type,
            ordinal,
            is_primary_key,
            is_nullable,
        }
    }

    /// `users(id INT PK NOT NULL, email TEXT NOT NULL, name TEXT NULL, ssn TEXT NULL)`.
    fn create_test_schema() -> TableSchema {
        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns: vec![
                column("id", ColumnType::Int, ColumnId(0), true, false),
                column("email", ColumnType::Text, ColumnId(1), false, false),
                column("name", ColumnType::Text, ColumnId(2), false, true),
                column("ssn", ColumnType::Text, ColumnId(3), false, true),
            ],
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    #[test]
    fn test_wildcard_column_match() {
        let config = config_with_rules(vec![make_rule(
            "*.email",
            StrategyKind::Hash {
                preserve_domain: true,
            },
        )]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        let strategies = matcher.get_strategies("users", &schema);

        // id: skip, email: hash, name: skip, ssn: skip
        assert!(matches!(strategies[0], StrategyKind::Skip));
        assert!(matches!(strategies[1], StrategyKind::Hash { .. }));
        assert!(matches!(strategies[2], StrategyKind::Skip));
        assert!(matches!(strategies[3], StrategyKind::Skip));
    }

    #[test]
    fn test_exact_column_match() {
        let config = config_with_rules(vec![make_rule("users.ssn", StrategyKind::Null)]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        // Should match users.ssn
        let strategy = matcher.get_strategy("users", "ssn");
        assert!(matches!(strategy, StrategyKind::Null));

        // Should NOT match other_table.ssn
        let strategy = matcher.get_strategy("other_table", "ssn");
        assert!(matches!(strategy, StrategyKind::Skip));
    }

    #[test]
    fn test_unqualified_pattern_matches_every_table() {
        let config = config_with_rules(vec![make_rule("ssn", StrategyKind::Null)]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        assert!(matches!(matcher.get_strategy("users", "ssn"), StrategyKind::Null));
        assert!(matches!(matcher.get_strategy("orders", "ssn"), StrategyKind::Null));
        assert!(matches!(matcher.get_strategy("users", "email"), StrategyKind::Skip));
    }

    #[test]
    fn test_rule_priority() {
        let config = config_with_rules(vec![
            // More specific rule first
            make_rule("admins.email", StrategyKind::Skip),
            // General rule second
            make_rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
        ]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        // admins.email should skip (first rule)
        let strategy = matcher.get_strategy("admins", "email");
        assert!(matches!(strategy, StrategyKind::Skip));

        // users.email should hash (second rule)
        let strategy = matcher.get_strategy("users", "email");
        assert!(matches!(strategy, StrategyKind::Hash { .. }));
    }

    #[test]
    fn test_matching_ignores_identifier_case() {
        let config = config_with_rules(vec![
            make_rule("users.ssn", StrategyKind::Null),
            make_rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
        ]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        assert!(matches!(matcher.get_strategy("USERS", "SSN"), StrategyKind::Null));
        assert!(matches!(
            matcher.get_strategy("Orders", "Email"),
            StrategyKind::Hash { .. }
        ));
    }

    #[test]
    fn test_invalid_pattern_is_rejected() {
        let bad_table = config_with_rules(vec![make_rule("[.email", StrategyKind::Null)]);
        let err = ColumnMatcher::from_config(&bad_table).unwrap_err();
        assert!(err.to_string().contains("Invalid table pattern '['"));

        let bad_column = config_with_rules(vec![make_rule("*.[", StrategyKind::Null)]);
        let err = ColumnMatcher::from_config(&bad_column).unwrap_err();
        assert!(err.to_string().contains("Invalid column pattern '['"));
    }

    #[test]
    fn test_count_matches() {
        let config = config_with_rules(vec![
            make_rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
            make_rule("*.ssn", StrategyKind::Null),
        ]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        // Should match email and ssn (2 columns)
        assert_eq!(matcher.count_matches("users", &schema), 2);
    }

    #[test]
    fn test_count_matches_uses_resolved_strategy() {
        // `id` is explicitly skipped; every other column falls through to a
        // non-Skip default and is therefore counted.
        let mut config = config_with_rules(vec![make_rule("*.id", StrategyKind::Skip)]);
        config.default_strategy = StrategyKind::Null;

        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        assert_eq!(matcher.count_matches("users", &schema), 3);
    }
}