Skip to main content

sql_splitter/redactor/
matcher.rs

1//! Column pattern matching for redaction rules.
2
3use crate::redactor::config::{RedactConfig, Rule};
4use crate::redactor::StrategyKind;
5use crate::schema::TableSchema;
6use glob::Pattern;
7
8/// Compiled column matcher for efficient pattern matching
9#[derive(Debug)]
10pub struct ColumnMatcher {
11    /// Compiled rules with glob patterns
12    rules: Vec<CompiledRule>,
13    /// Default strategy for unmatched columns
14    default_strategy: StrategyKind,
15}
16
17/// A rule with pre-compiled glob patterns
18#[derive(Debug)]
19struct CompiledRule {
20    /// Table pattern (None = match all tables)
21    table_pattern: Option<Pattern>,
22    /// Column pattern
23    column_pattern: Pattern,
24    /// The strategy to apply
25    strategy: StrategyKind,
26}
27
28impl ColumnMatcher {
29    /// Create a new matcher from configuration
30    pub fn from_config(config: &RedactConfig) -> anyhow::Result<Self> {
31        let mut rules = Vec::with_capacity(config.rules.len());
32
33        for rule in &config.rules {
34            let compiled = Self::compile_rule(rule)?;
35            rules.push(compiled);
36        }
37
38        Ok(Self {
39            rules,
40            default_strategy: config.default_strategy.clone(),
41        })
42    }
43
44    /// Compile a rule into table and column patterns
45    fn compile_rule(rule: &Rule) -> anyhow::Result<CompiledRule> {
46        let pattern = &rule.column;
47
48        // Check if pattern contains a table qualifier (table.column)
49        let (table_pattern, column_pattern) = if let Some(dot_pos) = pattern.find('.') {
50            let table_part = &pattern[..dot_pos];
51            let column_part = &pattern[dot_pos + 1..];
52
53            // Compile table pattern (might be * for all tables)
54            let table_pat = if table_part == "*" {
55                None
56            } else {
57                Some(Pattern::new(table_part).map_err(|e| {
58                    anyhow::anyhow!("Invalid table pattern '{}': {}", table_part, e)
59                })?)
60            };
61
62            let col_pat = Pattern::new(column_part)
63                .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", column_part, e))?;
64
65            (table_pat, col_pat)
66        } else {
67            // No table qualifier - match all tables
68            let col_pat = Pattern::new(pattern)
69                .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", pattern, e))?;
70            (None, col_pat)
71        };
72
73        Ok(CompiledRule {
74            table_pattern,
75            column_pattern,
76            strategy: rule.strategy.clone(),
77        })
78    }
79
80    /// Get the strategy for a specific column
81    pub fn get_strategy(&self, table_name: &str, column_name: &str) -> StrategyKind {
82        // Find first matching rule (rules are processed in order)
83        for rule in &self.rules {
84            if self.rule_matches(rule, table_name, column_name) {
85                return rule.strategy.clone();
86            }
87        }
88
89        // No match - return default
90        self.default_strategy.clone()
91    }
92
93    /// Get strategies for all columns in a table
94    pub fn get_strategies(&self, table_name: &str, table: &TableSchema) -> Vec<StrategyKind> {
95        table
96            .columns
97            .iter()
98            .map(|col| self.get_strategy(table_name, &col.name))
99            .collect()
100    }
101
102    /// Count how many columns match any redaction rule
103    pub fn count_matches(&self, table_name: &str, table: &TableSchema) -> usize {
104        table
105            .columns
106            .iter()
107            .filter(|col| {
108                let strategy = self.get_strategy(table_name, &col.name);
109                !matches!(strategy, StrategyKind::Skip)
110            })
111            .count()
112    }
113
114    /// Check if a rule matches a table/column pair
115    fn rule_matches(&self, rule: &CompiledRule, table_name: &str, column_name: &str) -> bool {
116        // Check table pattern (if specified)
117        if let Some(ref table_pat) = rule.table_pattern {
118            if !table_pat.matches(table_name) && !table_pat.matches(&table_name.to_lowercase()) {
119                return false;
120            }
121        }
122
123        // Check column pattern
124        rule.column_pattern.matches(column_name)
125            || rule.column_pattern.matches(&column_name.to_lowercase())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-column `users` table (`id`, `email`, `name`, `ssn`) shared by the tests.
    fn create_test_schema() -> TableSchema {
        use crate::schema::{Column, ColumnId, ColumnType, TableId};

        let column = |name: &str,
                      col_type: ColumnType,
                      ordinal: ColumnId,
                      is_primary_key: bool,
                      is_nullable: bool| Column {
            name: name.to_string(),
            col_type,
            ordinal,
            is_primary_key,
            is_nullable,
        };

        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns: vec![
                column("id", ColumnType::Int, ColumnId(0), true, false),
                column("email", ColumnType::Text, ColumnId(1), false, false),
                column("name", ColumnType::Text, ColumnId(2), false, true),
                column("ssn", ColumnType::Text, ColumnId(3), false, true),
            ],
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    /// Build a configuration holding `rules`, with `Skip` as the default
    /// strategy and the remaining fields set to the fixed values every test
    /// in this module uses.
    fn config_with_rules(rules: Vec<Rule>) -> RedactConfig {
        RedactConfig {
            input: std::path::PathBuf::new(),
            output: None,
            dialect: crate::parser::SqlDialect::MySql,
            rules,
            default_strategy: StrategyKind::Skip,
            seed: None,
            locale: "en".to_string(),
            tables_filter: None,
            exclude: vec![],
            strict: false,
            progress: false,
            dry_run: false,
        }
    }

    /// Shorthand for a rule mapping the `column` specification to `strategy`.
    fn rule(column: &str, strategy: StrategyKind) -> Rule {
        Rule {
            column: column.to_string(),
            strategy,
        }
    }

    #[test]
    fn test_wildcard_column_match() {
        let config = config_with_rules(vec![rule(
            "*.email",
            StrategyKind::Hash {
                preserve_domain: true,
            },
        )]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        let strategies = matcher.get_strategies("users", &create_test_schema());

        // id: skip, email: hash, name: skip, ssn: skip
        assert!(matches!(
            strategies.as_slice(),
            [
                StrategyKind::Skip,
                StrategyKind::Hash { .. },
                StrategyKind::Skip,
                StrategyKind::Skip,
            ]
        ));
    }

    #[test]
    fn test_exact_column_match() {
        let config = config_with_rules(vec![rule("users.ssn", StrategyKind::Null)]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        // The qualified rule applies to users.ssn ...
        assert!(matches!(
            matcher.get_strategy("users", "ssn"),
            StrategyKind::Null
        ));

        // ... but not to an ssn column in any other table.
        assert!(matches!(
            matcher.get_strategy("other_table", "ssn"),
            StrategyKind::Skip
        ));
    }

    #[test]
    fn test_rule_priority() {
        let config = config_with_rules(vec![
            // More specific rule first
            rule("admins.email", StrategyKind::Skip),
            // General rule second
            rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
        ]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();

        // admins.email is claimed by the first rule, so it is skipped.
        assert!(matches!(
            matcher.get_strategy("admins", "email"),
            StrategyKind::Skip
        ));

        // users.email falls through to the second rule and is hashed.
        assert!(matches!(
            matcher.get_strategy("users", "email"),
            StrategyKind::Hash { .. }
        ));
    }

    #[test]
    fn test_count_matches() {
        let config = config_with_rules(vec![
            rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
            rule("*.ssn", StrategyKind::Null),
        ]);
        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        // email and ssn are the only columns with a non-skip strategy.
        assert_eq!(matcher.count_matches("users", &schema), 2);
    }
}