Skip to main content

simple_waf_scanner/
scanner.rs

1use crate::{
2    config::Config,
3    evasion,
4    extractor::DataExtractor,
5    fingerprints::{DetectionResponse, WafDetector},
6    http::{build_client, send_request},
7    payloads::PayloadManager,
8    types::{Finding, ScanResults, ScanSummary},
9};
10use std::collections::HashSet;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::sync::Semaphore;
14use tokio::time::{sleep, Duration};
15
16/// WAF scanner
17pub struct Scanner {
18    config: Config,
19    client: reqwest::Client,
20    payload_manager: PayloadManager,
21    waf_detector: WafDetector,
22    data_extractor: DataExtractor,
23}
24
25impl Scanner {
26    /// Create a new scanner
27    pub async fn new(config: Config) -> crate::error::Result<Self> {
28        config.validate()?;
29
30        let client = build_client(&config)?;
31
32        let payload_manager = if let Some(ref payload_file) = config.payload_file {
33            tracing::info!("Loading custom payloads from: {}", payload_file);
34            PayloadManager::from_file(payload_file).await?
35        } else {
36            tracing::info!("Loading default embedded payloads");
37            PayloadManager::with_defaults()?
38        };
39
40        let waf_detector = WafDetector::new()?;
41        let data_extractor = DataExtractor::new();
42
43        Ok(Self {
44            config,
45            client,
46            payload_manager,
47            waf_detector,
48            data_extractor,
49        })
50    }
51
52    /// Perform the WAF bypass scan
53    #[tracing::instrument(skip(self), fields(target = %self.config.target))]
54    pub async fn scan(&self) -> crate::error::Result<ScanResults> {
55        let start_time = Instant::now();
56
57        tracing::info!("Starting WAF scan on {}", self.config.target);
58
59        // Step 1: Detect WAF
60        let waf_detected = self.detect_waf().await?;
61
62        if let Some(ref waf_name) = waf_detected {
63            tracing::info!("Detected WAF: {}", waf_name);
64        } else {
65            tracing::info!("No WAF detected");
66        }
67
68        // Step 2: Run payload tests
69        let mut results = ScanResults::new(self.config.target.clone(), waf_detected);
70        let findings = self.test_payloads().await?;
71
72        for finding in findings {
73            results.add_finding(finding);
74        }
75
76        // Step 3: Calculate summary
77        results.sort_by_severity();
78
79        let techniques_used: HashSet<_> = results
80            .findings
81            .iter()
82            .filter_map(|f| f.technique_used.as_ref())
83            .collect();
84
85        results.summary = ScanSummary {
86            total_payloads: self.payload_manager.payloads().len(),
87            successful_bypasses: results.findings.len(),
88            techniques_effective: techniques_used.len(),
89            duration_secs: start_time.elapsed().as_secs_f64(),
90        };
91
92        tracing::info!(
93            "Scan complete. Found {} successful bypasses in {:.2}s",
94            results.summary.successful_bypasses,
95            results.summary.duration_secs
96        );
97
98        Ok(results)
99    }
100
101    /// Detect WAF by sending a baseline request
102    async fn detect_waf(&self) -> crate::error::Result<Option<String>> {
103        tracing::debug!("Sending baseline request for WAF detection");
104
105        let response = send_request(&self.client, &self.config.target, None)
106            .await
107            .map_err(|e| {
108                tracing::error!("Connection failed: {}", e);
109                e // Just pass through the reqwest::Error
110            })?;
111
112        // Log HTTP version information
113        tracing::info!(
114            "Target {} is using HTTP version: {}",
115            self.config.target,
116            response.http_version
117        );
118
119        if response.http_version.contains("HTTP/2") {
120            tracing::info!("✓ HTTP/2 protocol detected - production-ready configuration active");
121        } else {
122            tracing::warn!("⚠ HTTP/1.x detected - some HTTP/2 tests may not apply");
123        }
124
125        let detection_response = DetectionResponse::new(
126            response.status_code,
127            response.headers,
128            response.body,
129            response.cookies,
130        );
131
132        Ok(self.waf_detector.detect(&detection_response))
133    }
134
135    /// Test all payloads with evasion techniques
136    async fn test_payloads(&self) -> crate::error::Result<Vec<Finding>> {
137        let payloads = self.payload_manager.payloads();
138        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));
139        let mut tasks = Vec::new();
140
141        tracing::info!("Testing {} payloads", payloads.len());
142
143        for payload in payloads {
144            for payload_test in &payload.payloads {
145                // Apply all evasion techniques
146                let technique_variants = evasion::apply_all_techniques(
147                    &payload_test.value,
148                    self.config.enabled_techniques.as_deref(),
149                );
150
151                for (technique_name, transformed_payload) in technique_variants {
152                    let sem = semaphore.clone();
153                    let client = self.client.clone();
154                    let target = self.config.target.clone();
155                    let delay_ms = self.config.delay_ms;
156                    let payload_id = payload.id.clone();
157                    let severity = payload.info.severity;
158                    let category = payload.info.category.clone();
159                    let description = payload.info.description.clone();
160                    let matchers = payload.matchers.clone();
161                    let extractor = self.data_extractor.clone();
162
163                    let task = tokio::spawn(async move {
164                        let _permit = sem.acquire().await.unwrap();
165
166                        // Rate limiting delay
167                        if delay_ms > 0 {
168                            sleep(Duration::from_millis(delay_ms)).await;
169                        }
170
171                        // Send request with payload as query parameter
172                        let response =
173                            send_request(&client, &target, Some(("test", &transformed_payload)))
174                                .await;
175
176                        match response {
177                            Ok(resp) => {
178                                // Check if payload matched
179                                let matched = check_matchers(&resp, &matchers);
180
181                                if matched {
182                                    tracing::debug!(
183                                        "Payload {} matched with technique: {} (HTTP version: {})",
184                                        payload_id,
185                                        technique_name,
186                                        resp.http_version
187                                    );
188
189                                    // Extract sensitive data from response
190                                    let extracted_data = extractor.extract(
191                                        &resp.body,
192                                        &resp.headers,
193                                        &resp.cookies,
194                                    );
195
196                                    Some(Finding {
197                                        payload_id: payload_id.clone(),
198                                        severity,
199                                        category: category.clone(),
200                                        owasp_category: crate::types::OwaspCategory::from_attack_type(&category),
201                                        payload_value: transformed_payload,
202                                        technique_used: if technique_name == "Original" {
203                                            None
204                                        } else {
205                                            Some(technique_name)
206                                        },
207                                        response_status: resp.status_code,
208                                        description,
209                                        http_version: Some(resp.http_version),
210                                        extracted_data: if extracted_data.has_data() {
211                                            Some(extracted_data)
212                                        } else {
213                                            None
214                                        },
215                                    })
216                                } else {
217                                    None
218                                }
219                            }
220                            Err(e) => {
221                                tracing::warn!("Request failed for payload {}: {}", payload_id, e);
222                                None
223                            }
224                        }
225                    });
226
227                    tasks.push(task);
228                }
229            }
230        }
231
232        // Wait for all tasks to complete
233        let results = futures::future::join_all(tasks).await;
234
235        // Collect findings
236        let findings: Vec<Finding> = results
237            .into_iter()
238            .filter_map(|r| r.ok())
239            .flatten()
240            .collect();
241
242        Ok(findings)
243    }
244}
245
246/// Check if response matches any of the matchers
247fn check_matchers(
248    response: &crate::http::HttpResponse,
249    matchers: &[crate::payloads::Matcher],
250) -> bool {
251    for matcher in matchers {
252        match matcher.matcher_type.as_str() {
253            "response_body" => {
254                if matcher.condition == "contains" {
255                    for pattern in &matcher.patterns {
256                        if response.body.contains(pattern) {
257                            return true;
258                        }
259                    }
260                } else if matcher.condition == "not_contains" {
261                    let mut all_not_found = true;
262                    for pattern in &matcher.patterns {
263                        if response.body.contains(pattern) {
264                            all_not_found = false;
265                            break;
266                        }
267                    }
268                    if all_not_found {
269                        return true;
270                    }
271                }
272            }
273            "response_time" => {
274                if matcher.condition == "greater_than" {
275                    if let Some(threshold) = matcher.patterns.first() {
276                        if let Ok(threshold_ms) = threshold.parse::<u64>() {
277                            if response.response_time_ms > threshold_ms {
278                                return true;
279                            }
280                        }
281                    }
282                }
283            }
284            "response_status" => {
285                if matcher.condition == "equals" {
286                    for pattern in &matcher.patterns {
287                        if let Ok(expected_status) = pattern.parse::<u16>() {
288                            if response.status_code == expected_status {
289                                return true;
290                            }
291                        }
292                    }
293                } else if matcher.condition == "not_equals" {
294                    let mut all_different = true;
295                    for pattern in &matcher.patterns {
296                        if let Ok(expected_status) = pattern.parse::<u16>() {
297                            if response.status_code == expected_status {
298                                all_different = false;
299                                break;
300                            }
301                        }
302                    }
303                    if all_different {
304                        return true;
305                    }
306                }
307            }
308            "response_header" => {
309                if matcher.condition == "contains" {
310                    for pattern in &matcher.patterns {
311                        for (_, header_value) in &response.headers {
312                            if header_value.contains(pattern) {
313                                return true;
314                            }
315                        }
316                    }
317                } else if matcher.condition == "not_contains" {
318                    let mut found = false;
319                    for pattern in &matcher.patterns {
320                        for (_, header_value) in &response.headers {
321                            if header_value.contains(pattern) {
322                                found = true;
323                                break;
324                            }
325                        }
326                        if found {
327                            break;
328                        }
329                    }
330                    if !found {
331                        return true;
332                    }
333                }
334            }
335            "ai_refusal_bypass" => {
336                // Check for absence of AI refusal patterns (indicates successful jailbreak)
337                if matcher.condition == "not_contains" {
338                    let mut all_not_found = true;
339                    for pattern in &matcher.patterns {
340                        if response.body.to_lowercase().contains(&pattern.to_lowercase()) {
341                            all_not_found = false;
342                            break;
343                        }
344                    }
345                    if all_not_found {
346                        return true;
347                    }
348                }
349            }
350            "system_prompt_leak" => {
351                // Detect system prompt leakage indicators
352                if matcher.condition == "contains" {
353                    for pattern in &matcher.patterns {
354                        if response.body.contains(pattern) {
355                            return true;
356                        }
357                    }
358                }
359            }
360            "jailbreak_success" => {
361                // Detect jailbreak success indicators (e.g., compliance with malicious request)
362                if matcher.condition == "contains" {
363                    for pattern in &matcher.patterns {
364                        if response.body.to_lowercase().contains(&pattern.to_lowercase()) {
365                            return true;
366                        }
367                    }
368                }
369            }
370            "response_json" => {
371                // Validate JSON structure for improper output handling
372                if matcher.condition == "valid" {
373                    // Check if response body is valid JSON
374                    if serde_json::from_str::<serde_json::Value>(&response.body).is_ok() {
375                        return true;
376                    }
377                } else if matcher.condition == "contains_field" {
378                    // Check if JSON contains specific fields
379                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(&response.body) {
380                        for pattern in &matcher.patterns {
381                            if json.get(pattern).is_some() {
382                                return true;
383                            }
384                        }
385                    }
386                }
387            }
388            _ => {
389                tracing::warn!("Unknown matcher type: {}", matcher.matcher_type);
390            }
391        }
392    }
393
394    false
395}