1use crate::redactor::config::{RedactConfig, Rule};
4use crate::redactor::StrategyKind;
5use crate::schema::TableSchema;
6use glob::Pattern;
7
8#[derive(Debug)]
10pub struct ColumnMatcher {
11 rules: Vec<CompiledRule>,
13 default_strategy: StrategyKind,
15}
16
17#[derive(Debug)]
19struct CompiledRule {
20 table_pattern: Option<Pattern>,
22 column_pattern: Pattern,
24 strategy: StrategyKind,
26}
27
28impl ColumnMatcher {
29 pub fn from_config(config: &RedactConfig) -> anyhow::Result<Self> {
31 let mut rules = Vec::with_capacity(config.rules.len());
32
33 for rule in &config.rules {
34 let compiled = Self::compile_rule(rule)?;
35 rules.push(compiled);
36 }
37
38 Ok(Self {
39 rules,
40 default_strategy: config.default_strategy.clone(),
41 })
42 }
43
44 fn compile_rule(rule: &Rule) -> anyhow::Result<CompiledRule> {
46 let pattern = &rule.column;
47
48 let (table_pattern, column_pattern) = if let Some(dot_pos) = pattern.find('.') {
50 let table_part = &pattern[..dot_pos];
51 let column_part = &pattern[dot_pos + 1..];
52
53 let table_pat = if table_part == "*" {
55 None
56 } else {
57 Some(
58 Pattern::new(table_part)
59 .map_err(|e| anyhow::anyhow!("Invalid table pattern '{}': {}", table_part, e))?,
60 )
61 };
62
63 let col_pat = Pattern::new(column_part)
64 .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", column_part, e))?;
65
66 (table_pat, col_pat)
67 } else {
68 let col_pat = Pattern::new(pattern)
70 .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", pattern, e))?;
71 (None, col_pat)
72 };
73
74 Ok(CompiledRule {
75 table_pattern,
76 column_pattern,
77 strategy: rule.strategy.clone(),
78 })
79 }
80
81 pub fn get_strategy(&self, table_name: &str, column_name: &str) -> StrategyKind {
83 for rule in &self.rules {
85 if self.rule_matches(rule, table_name, column_name) {
86 return rule.strategy.clone();
87 }
88 }
89
90 self.default_strategy.clone()
92 }
93
94 pub fn get_strategies(&self, table_name: &str, table: &TableSchema) -> Vec<StrategyKind> {
96 table
97 .columns
98 .iter()
99 .map(|col| self.get_strategy(table_name, &col.name))
100 .collect()
101 }
102
103 pub fn count_matches(&self, table_name: &str, table: &TableSchema) -> usize {
105 table
106 .columns
107 .iter()
108 .filter(|col| {
109 let strategy = self.get_strategy(table_name, &col.name);
110 !matches!(strategy, StrategyKind::Skip)
111 })
112 .count()
113 }
114
115 fn rule_matches(&self, rule: &CompiledRule, table_name: &str, column_name: &str) -> bool {
117 if let Some(ref table_pat) = rule.table_pattern {
119 if !table_pat.matches(table_name) && !table_pat.matches(&table_name.to_lowercase()) {
120 return false;
121 }
122 }
123
124 rule.column_pattern.matches(column_name)
126 || rule.column_pattern.matches(&column_name.to_lowercase())
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `users` table with `id`, `email`, `name` and `ssn` columns.
    fn create_test_schema() -> TableSchema {
        use crate::schema::{Column, ColumnId, ColumnType, TableId};

        let columns = vec![
            Column {
                name: "id".to_string(),
                col_type: ColumnType::Int,
                ordinal: ColumnId(0),
                is_primary_key: true,
                is_nullable: false,
            },
            Column {
                name: "email".to_string(),
                col_type: ColumnType::Text,
                ordinal: ColumnId(1),
                is_primary_key: false,
                is_nullable: false,
            },
            Column {
                name: "name".to_string(),
                col_type: ColumnType::Text,
                ordinal: ColumnId(2),
                is_primary_key: false,
                is_nullable: true,
            },
            Column {
                name: "ssn".to_string(),
                col_type: ColumnType::Text,
                ordinal: ColumnId(3),
                is_primary_key: false,
                is_nullable: true,
            },
        ];

        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns,
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    /// Shorthand for a rule with the given column specification and strategy.
    fn rule(column: &str, strategy: StrategyKind) -> Rule {
        Rule {
            column: column.to_string(),
            strategy,
        }
    }

    /// Builds a matcher over `rules` with `Skip` as the default strategy;
    /// all other configuration fields are fixed test defaults.
    fn matcher_for(rules: Vec<Rule>) -> ColumnMatcher {
        let config = RedactConfig {
            input: std::path::PathBuf::new(),
            output: None,
            dialect: crate::parser::SqlDialect::MySql,
            rules,
            default_strategy: StrategyKind::Skip,
            seed: None,
            locale: "en".to_string(),
            tables_filter: None,
            exclude: vec![],
            strict: false,
            progress: false,
            dry_run: false,
        };

        ColumnMatcher::from_config(&config).unwrap()
    }

    #[test]
    fn test_wildcard_column_match() {
        let matcher = matcher_for(vec![rule(
            "*.email",
            StrategyKind::Hash { preserve_domain: true },
        )]);
        let schema = create_test_schema();

        let strategies = matcher.get_strategies("users", &schema);

        // Only the `email` column (index 1) is covered by the rule.
        assert!(matches!(strategies[0], StrategyKind::Skip));
        assert!(matches!(strategies[1], StrategyKind::Hash { .. }));
        assert!(matches!(strategies[2], StrategyKind::Skip));
        assert!(matches!(strategies[3], StrategyKind::Skip));
    }

    #[test]
    fn test_exact_column_match() {
        let matcher = matcher_for(vec![rule("users.ssn", StrategyKind::Null)]);

        // The rule is scoped to the `users` table.
        assert!(matches!(
            matcher.get_strategy("users", "ssn"),
            StrategyKind::Null
        ));

        // The same column name in another table falls back to the default.
        assert!(matches!(
            matcher.get_strategy("other_table", "ssn"),
            StrategyKind::Skip
        ));
    }

    #[test]
    fn test_rule_priority() {
        let matcher = matcher_for(vec![
            rule("admins.email", StrategyKind::Skip),
            rule("*.email", StrategyKind::Hash { preserve_domain: false }),
        ]);

        // The earlier, more specific rule wins for `admins`.
        assert!(matches!(
            matcher.get_strategy("admins", "email"),
            StrategyKind::Skip
        ));

        // Every other table reaches the wildcard rule.
        assert!(matches!(
            matcher.get_strategy("users", "email"),
            StrategyKind::Hash { .. }
        ));
    }

    #[test]
    fn test_count_matches() {
        let matcher = matcher_for(vec![
            rule("*.email", StrategyKind::Hash { preserve_domain: false }),
            rule("*.ssn", StrategyKind::Null),
        ]);
        let schema = create_test_schema();

        // `email` and `ssn` get a non-skip strategy.
        assert_eq!(matcher.count_matches("users", &schema), 2);
    }
}