Skip to main content

wafrift_encoding/tamper/
config.rs

1//! TOML loading support for tamper strategies.
2
3use std::collections::HashMap;
4
5use super::{TamperError, TamperRegistry};
6
7/// Configuration for tamper strategies loaded from TOML.
8#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
9pub struct StrategyConfig {
10    /// Strategy name
11    pub name: String,
12    /// Whether this strategy is enabled
13    pub enabled: bool,
14    /// Optional context hints (e.g., ["sql", "xss"])
15    pub contexts: Option<Vec<String>>,
16    /// Custom parameters for the strategy
17    pub params: Option<HashMap<String, toml::Value>>,
18}
19
20/// Full configuration for all tamper strategies.
21#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
22pub struct TamperConfig {
23    /// List of strategy configurations
24    pub strategies: Vec<StrategyConfig>,
25}
26
27impl TamperRegistry {
28    /// Loads strategy configurations from a TOML file.
29    ///
30    /// # Errors
31    /// Returns an error if the file cannot be read or parsed.
32    pub fn load_toml<P: AsRef<std::path::Path>>(
33        &mut self,
34        path: P,
35    ) -> Result<TamperConfig, TamperError> {
36        let content = std::fs::read_to_string(path.as_ref())
37            .map_err(|e| TamperError::LoadError(format!("Failed to read file: {e}")))?;
38
39        let config: TamperConfig = toml::from_str(&content)
40            .map_err(|e| TamperError::InvalidConfig(format!("Failed to parse TOML: {e}")))?;
41
42        Ok(config)
43    }
44
45    /// Applies all enabled strategies from a configuration.
46    ///
47    /// Strategies are applied in order of aggressiveness (least to most).
48    pub fn apply_config(&self, payload: &str, config: &TamperConfig) -> Vec<(String, String)> {
49        let mut results = Vec::new();
50
51        for strategy_config in &config.strategies {
52            if !strategy_config.enabled {
53                continue;
54            }
55
56            if let Some(strategy) = self.get(&strategy_config.name) {
57                let context = strategy_config
58                    .contexts
59                    .as_ref()
60                    .and_then(|v| v.first().map(std::string::String::as_str));
61                let result = if let Some(ref params) = strategy_config.params {
62                    strategy.tamper_with_params(payload, context, params)
63                } else {
64                    strategy.tamper(payload, context)
65                };
66                results.push((strategy_config.name.clone(), result));
67            }
68        }
69
70        results
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tamper_config_serialization() {
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: Some(vec!["sql".to_string(), "xss".to_string()]),
                    params: None,
                },
                StrategyConfig {
                    name: "base64".to_string(),
                    enabled: false,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let toml_str = toml::to_string(&config).expect("Failed to serialize config");
        assert!(toml_str.contains("url_encode"));
        assert!(toml_str.contains("enabled = true"));
        assert!(toml_str.contains("enabled = false"));

        let deserialized: TamperConfig =
            toml::from_str(&toml_str).expect("Failed to deserialize config");
        assert_eq!(deserialized.strategies.len(), 2);
        assert!(deserialized.strategies[0].enabled);
        assert!(!deserialized.strategies[1].enabled);
    }

    #[test]
    fn apply_config_filters_disabled() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "base64".to_string(),
                    enabled: false,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("test", &config);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "url_encode");
    }

    #[test]
    fn apply_config_skips_unknown_strategy() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "no_such_strategy".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("test", &config);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "url_encode");
    }

    #[test]
    fn apply_config_preserves_config_order() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "case_alternation".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("select <", &config);
        let names: Vec<&str> = results.iter().map(|(n, _)| n.as_str()).collect();
        assert_eq!(names, ["url_encode", "case_alternation"]);
    }

    #[test]
    fn apply_config_with_context() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![StrategyConfig {
                name: "sql_comment".to_string(),
                enabled: true,
                contexts: Some(vec!["sql".to_string()]),
                params: None,
            }],
        };

        let results = registry.apply_config("SELECT * FROM", &config);
        assert_eq!(results.len(), 1);
        assert!(results[0].1.contains("/**/"));
    }

    #[test]
    fn strategy_config_roundtrip() {
        let config_str = r#"
[[strategies]]
name = "url_encode"
enabled = true
contexts = ["sql", "xss"]
"#;

        let config: TamperConfig = toml::from_str(config_str).expect("Failed to parse TOML");
        assert_eq!(config.strategies.len(), 1);
        assert_eq!(config.strategies[0].name, "url_encode");
        assert!(config.strategies[0].enabled);
        assert_eq!(
            config.strategies[0].contexts,
            Some(vec!["sql".to_string(), "xss".to_string()])
        );
    }

    #[test]
    fn load_toml_from_strategies_d() {
        let mut registry = TamperRegistry::with_defaults();
        let path = std::path::Path::new(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../strategies.d/core.toml"
        ));

        // The shipped strategy directory is optional in some checkouts; only
        // validate it when present.
        if path.exists() {
            let config = registry.load_toml(path).expect("Failed to load core.toml");
            let has_url_encode = config
                .strategies
                .iter()
                .any(|s| s.name == "url_encode" && s.enabled);
            assert!(has_url_encode, "core.toml should have url_encode enabled");
        }
    }

    #[test]
    fn tamper_error_invalid_toml() {
        let mut registry = TamperRegistry::with_defaults();
        let invalid_toml = "not valid toml [[";

        // Per-process file name so concurrent test runs don't clobber each other.
        let temp_file = std::env::temp_dir()
            .join(format!("invalid_toml_test_{}.toml", std::process::id()));
        std::fs::write(&temp_file, invalid_toml).expect("Failed to write temp TOML file");

        let result = registry.load_toml(&temp_file);
        // Clean up before asserting so a failing assertion doesn't leave the file behind.
        std::fs::remove_file(&temp_file).ok();

        assert!(matches!(result, Err(TamperError::InvalidConfig(_))));
    }

    #[test]
    fn tamper_error_missing_file() {
        let mut registry = TamperRegistry::with_defaults();
        let result = registry.load_toml("/nonexistent/path/file.toml");
        assert!(matches!(result, Err(TamperError::LoadError(_))));
    }

    #[test]
    fn layered_tamper_chain() {
        let registry = TamperRegistry::with_defaults();
        let config = TamperConfig {
            strategies: vec![
                StrategyConfig {
                    name: "case_alternation".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
                StrategyConfig {
                    name: "url_encode".to_string(),
                    enabled: true,
                    contexts: None,
                    params: None,
                },
            ],
        };

        let results = registry.apply_config("select <", &config);
        assert_eq!(results.len(), 2);

        assert!(results.iter().any(|(n, _)| n == "case_alternation"));
        assert!(results.iter().any(|(n, _)| n == "url_encode"));

        let url_result = results
            .iter()
            .find(|(n, _)| n == "url_encode")
            .expect("url_encode result should be present");
        assert!(url_result.1.contains('%'));
    }

    #[test]
    fn tamper_strategy_trait_object_safety() {
        let strategies: Vec<Box<dyn super::super::TamperStrategy>> = vec![
            Box::new(super::super::UrlEncodeTamper),
            Box::new(super::super::Base64Tamper),
            Box::new(super::super::CaseAlternationTamper),
        ];

        for strategy in &strategies {
            let result = strategy.tamper("test", None);
            assert!(!result.is_empty());
            assert!(strategy.aggressiveness() >= 0.0 && strategy.aggressiveness() <= 1.0);
        }
    }

    #[test]
    fn custom_strategy_params() {
        let config = StrategyConfig {
            name: "custom".to_string(),
            enabled: true,
            contexts: None,
            params: {
                let mut map = std::collections::HashMap::new();
                map.insert("level".to_string(), toml::Value::Integer(5));
                map.insert(
                    "prefix".to_string(),
                    toml::Value::String("test_".to_string()),
                );
                Some(map)
            },
        };

        let params = config.params.as_ref().expect("params should be set");
        assert_eq!(params.get("level").and_then(toml::Value::as_integer), Some(5));
        assert_eq!(params.get("prefix").and_then(toml::Value::as_str), Some("test_"));
    }
}