1use crate::redactor::config::{RedactConfig, Rule};
4use crate::redactor::StrategyKind;
5use crate::schema::TableSchema;
6use glob::Pattern;
7
8#[derive(Debug)]
10pub struct ColumnMatcher {
11 rules: Vec<CompiledRule>,
13 default_strategy: StrategyKind,
15}
16
17#[derive(Debug)]
19struct CompiledRule {
20 table_pattern: Option<Pattern>,
22 column_pattern: Pattern,
24 strategy: StrategyKind,
26}
27
28impl ColumnMatcher {
29 pub fn from_config(config: &RedactConfig) -> anyhow::Result<Self> {
31 let mut rules = Vec::with_capacity(config.rules.len());
32
33 for rule in &config.rules {
34 let compiled = Self::compile_rule(rule)?;
35 rules.push(compiled);
36 }
37
38 Ok(Self {
39 rules,
40 default_strategy: config.default_strategy.clone(),
41 })
42 }
43
44 fn compile_rule(rule: &Rule) -> anyhow::Result<CompiledRule> {
46 let pattern = &rule.column;
47
48 let (table_pattern, column_pattern) = if let Some(dot_pos) = pattern.find('.') {
50 let table_part = &pattern[..dot_pos];
51 let column_part = &pattern[dot_pos + 1..];
52
53 let table_pat = if table_part == "*" {
55 None
56 } else {
57 Some(Pattern::new(table_part).map_err(|e| {
58 anyhow::anyhow!("Invalid table pattern '{}': {}", table_part, e)
59 })?)
60 };
61
62 let col_pat = Pattern::new(column_part)
63 .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", column_part, e))?;
64
65 (table_pat, col_pat)
66 } else {
67 let col_pat = Pattern::new(pattern)
69 .map_err(|e| anyhow::anyhow!("Invalid column pattern '{}': {}", pattern, e))?;
70 (None, col_pat)
71 };
72
73 Ok(CompiledRule {
74 table_pattern,
75 column_pattern,
76 strategy: rule.strategy.clone(),
77 })
78 }
79
80 pub fn get_strategy(&self, table_name: &str, column_name: &str) -> StrategyKind {
82 for rule in &self.rules {
84 if self.rule_matches(rule, table_name, column_name) {
85 return rule.strategy.clone();
86 }
87 }
88
89 self.default_strategy.clone()
91 }
92
93 pub fn get_strategies(&self, table_name: &str, table: &TableSchema) -> Vec<StrategyKind> {
95 table
96 .columns
97 .iter()
98 .map(|col| self.get_strategy(table_name, &col.name))
99 .collect()
100 }
101
102 pub fn count_matches(&self, table_name: &str, table: &TableSchema) -> usize {
104 table
105 .columns
106 .iter()
107 .filter(|col| {
108 let strategy = self.get_strategy(table_name, &col.name);
109 !matches!(strategy, StrategyKind::Skip)
110 })
111 .count()
112 }
113
114 fn rule_matches(&self, rule: &CompiledRule, table_name: &str, column_name: &str) -> bool {
116 if let Some(ref table_pat) = rule.table_pattern {
118 if !table_pat.matches(table_name) && !table_pat.matches(&table_name.to_lowercase()) {
119 return false;
120 }
121 }
122
123 rule.column_pattern.matches(column_name)
125 || rule.column_pattern.matches(&column_name.to_lowercase())
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `users` table with the columns `id`, `email`, `name` and
    /// `ssn`, in that order.
    fn create_test_schema() -> TableSchema {
        use crate::schema::{Column, ColumnId, ColumnType, TableId};

        TableSchema {
            name: "users".to_string(),
            id: TableId(0),
            columns: vec![
                Column {
                    name: "id".to_string(),
                    col_type: ColumnType::Int,
                    ordinal: ColumnId(0),
                    is_primary_key: true,
                    is_nullable: false,
                },
                Column {
                    name: "email".to_string(),
                    col_type: ColumnType::Text,
                    ordinal: ColumnId(1),
                    is_primary_key: false,
                    is_nullable: false,
                },
                Column {
                    name: "name".to_string(),
                    col_type: ColumnType::Text,
                    ordinal: ColumnId(2),
                    is_primary_key: false,
                    is_nullable: true,
                },
                Column {
                    name: "ssn".to_string(),
                    col_type: ColumnType::Text,
                    ordinal: ColumnId(3),
                    is_primary_key: false,
                    is_nullable: true,
                },
            ],
            primary_key: vec![ColumnId(0)],
            foreign_keys: vec![],
            indexes: vec![],
            create_statement: None,
        }
    }

    /// Builds a config that carries `rules` and a `Skip` default strategy;
    /// every other field is a neutral value the matcher does not read.
    fn config_with_rules(rules: Vec<Rule>) -> RedactConfig {
        RedactConfig {
            input: std::path::PathBuf::new(),
            output: None,
            dialect: crate::parser::SqlDialect::MySql,
            rules,
            default_strategy: StrategyKind::Skip,
            seed: None,
            locale: "en".to_string(),
            tables_filter: None,
            exclude: vec![],
            strict: false,
            progress: false,
            dry_run: false,
        }
    }

    /// Shorthand for a rule literal.
    fn rule(column: &str, strategy: StrategyKind) -> Rule {
        Rule {
            column: column.to_string(),
            strategy,
        }
    }

    #[test]
    fn test_wildcard_column_match() {
        let config = config_with_rules(vec![rule(
            "*.email",
            StrategyKind::Hash {
                preserve_domain: true,
            },
        )]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        let strategies = matcher.get_strategies("users", &schema);

        // Only `email` (index 1) is covered by the rule.
        assert!(matches!(strategies[0], StrategyKind::Skip));
        assert!(matches!(strategies[1], StrategyKind::Hash { .. }));
        assert!(matches!(strategies[2], StrategyKind::Skip));
        assert!(matches!(strategies[3], StrategyKind::Skip));
    }

    #[test]
    fn test_exact_column_match() {
        let config = config_with_rules(vec![rule("users.ssn", StrategyKind::Null)]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();

        let strategy = matcher.get_strategy("users", "ssn");
        assert!(matches!(strategy, StrategyKind::Null));

        // Same column name in another table falls back to the default.
        let strategy = matcher.get_strategy("other_table", "ssn");
        assert!(matches!(strategy, StrategyKind::Skip));
    }

    #[test]
    fn test_uppercase_names_match_lowercase_pattern() {
        let config = config_with_rules(vec![rule("users.ssn", StrategyKind::Null)]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();

        // Names that differ from the pattern only by case still match.
        let strategy = matcher.get_strategy("USERS", "SSN");
        assert!(matches!(strategy, StrategyKind::Null));
    }

    #[test]
    fn test_rule_priority() {
        let config = config_with_rules(vec![
            // The specific rule comes first, so it shadows the wildcard.
            rule("admins.email", StrategyKind::Skip),
            rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
        ]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();

        let strategy = matcher.get_strategy("admins", "email");
        assert!(matches!(strategy, StrategyKind::Skip));

        let strategy = matcher.get_strategy("users", "email");
        assert!(matches!(strategy, StrategyKind::Hash { .. }));
    }

    #[test]
    fn test_count_matches() {
        let config = config_with_rules(vec![
            rule(
                "*.email",
                StrategyKind::Hash {
                    preserve_domain: false,
                },
            ),
            rule("*.ssn", StrategyKind::Null),
        ]);

        let matcher = ColumnMatcher::from_config(&config).unwrap();
        let schema = create_test_schema();

        // `email` and `ssn` resolve to a non-Skip strategy.
        assert_eq!(matcher.count_matches("users", &schema), 2);
    }
}