sklears_core/
unsafe_audit.rs

//! Unsafe code auditing and minimization utilities.
//!
//! This module provides tools for auditing, tracking, and minimizing unsafe code usage
//! in the sklears ecosystem, with a focus on safety and correctness.
5use crate::error::SklearsError;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::path::{Path, PathBuf};
9
/// Result type alias for unsafe audit operations.
///
/// Fallible functions in this module report failures as [`SklearsError`].
pub type Result<T> = std::result::Result<T, SklearsError>;
12
/// Configuration for unsafe code auditing.
///
/// Consumed by [`UnsafeAuditor`]; the `Default` impl scans `src` only and installs
/// the built-in allow-list.
#[derive(Debug, Clone)]
pub struct UnsafeAuditConfig {
    /// Paths to scan for unsafe code; each one is joined onto the root passed to
    /// `UnsafeAuditor::audit` and silently skipped when it does not exist
    pub scan_paths: Vec<PathBuf>,
    /// Paths to exclude from auditing; a path is excluded when it ends with an entry
    /// or when any one of its components equals an entry
    pub exclude_paths: Vec<PathBuf>,
    /// Maximum allowed unsafe findings per file; a file exceeding it fails the audit
    /// (only consulted when `strict_mode` is off)
    pub max_unsafe_per_file: usize,
    /// Whether findings of `High` severity or above must carry a justification
    pub require_justification: bool,
    /// Whether to flag all unsafe code as errors (the audit passes only with no findings)
    pub strict_mode: bool,
    /// Known safe patterns to allow; a finding is marked known safe when its code
    /// contains one of a pattern's signatures as a substring
    pub allowed_patterns: Vec<UnsafePattern>,
}
29
/// Pattern for safe unsafe code usage.
///
/// One entry of the allow-list stored in `UnsafeAuditConfig::allowed_patterns`.
#[derive(Debug, Clone)]
pub struct UnsafePattern {
    /// Name/description of the pattern
    pub name: String,
    /// Function/method signatures that are considered safe; matched as plain
    /// substrings of the audited code
    pub signatures: Vec<String>,
    /// Justification for why this pattern is safe
    pub justification: String,
    /// Required preconditions for safety
    // NOTE(review): the auditor code visible in this file only reads `signatures`;
    // `justification` and `preconditions` look descriptive only — confirm.
    pub preconditions: Vec<String>,
}
42
/// Result of unsafe code audit.
///
/// Produced by `UnsafeAuditor::audit`.
#[derive(Debug, Clone)]
pub struct UnsafeAuditReport {
    /// Whether the audit passed all checks
    pub passed: bool,
    /// Total number of files scanned (`.rs` files that were not excluded)
    pub files_scanned: usize,
    /// Number of files with unsafe code (the number of keys in `findings`)
    pub files_with_unsafe: usize,
    /// Total number of unsafe blocks found; this is the total number of findings,
    /// so it includes both line-level and block-level findings
    pub total_unsafe_blocks: usize,
    /// Unsafe code findings per file; files without findings have no entry
    pub findings: HashMap<PathBuf, Vec<UnsafeFinding>>,
    /// Summary statistics
    pub summary: UnsafeSummary,
    /// Recommendations for improvement
    pub recommendations: Vec<SafetyRecommendation>,
}
61
/// Individual unsafe code finding.
///
/// Created either for a single line that starts with `unsafe` or for a whole
/// `unsafe { ... }` block.
#[derive(Debug, Clone)]
pub struct UnsafeFinding {
    /// File path where unsafe code was found
    pub file: PathBuf,
    /// Line number of the unsafe block (1-based; the opening line for block findings)
    pub line: usize,
    /// Column number (if available): the byte offset of the first `unsafe` in the
    /// line for line-level findings, `None` for block-level findings
    pub column: Option<usize>,
    /// Type of unsafe operation
    pub unsafe_type: UnsafeType,
    /// The actual unsafe code snippet (the full line, or all lines of the block)
    pub code_snippet: String,
    /// Justification provided (if any), taken from a `//` comment
    pub justification: Option<String>,
    /// Whether this pattern is known to be safe (matched the configured allow-list)
    pub is_known_safe: bool,
    /// Severity of the safety concern
    pub severity: SafetySeverity,
    /// Suggested alternatives or improvements
    pub suggestions: Vec<String>,
}
84
/// Type of unsafe operation.
///
/// Assigned by textual heuristics (substring checks), not by parsing the code.
/// Used as a key in `UnsafeSummary::types_breakdown`, hence `Eq + Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UnsafeType {
    /// Raw pointer dereferencing
    RawPointerDeref,
    /// Calling unsafe functions
    UnsafeFunctionCall,
    /// Mutable static access
    MutableStatic,
    /// Union field access
    UnionFieldAccess,
    /// Transmute operations
    Transmute,
    /// Inline assembly
    InlineAssembly,
    /// Generic unsafe block (also used for every block-level finding)
    UnsafeBlock,
}
103
/// Severity of safety concerns.
///
/// Variants are totally ordered from least to most severe:
/// `Info < Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SafetySeverity {
    /// Informational, pattern is known safe
    Info,
    /// Low risk, acceptable with documentation
    Low,
    /// Medium risk that should be justified
    Medium,
    /// High risk that should be reviewed
    High,
    /// Critical safety issue that must be addressed
    Critical,
}

impl PartialOrd for SafetySeverity {
    /// Always `Some`: the ordering is total and delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SafetySeverity {
    /// Compares two severities by their position on the `Info..=Critical` ladder.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Position of a level on the ladder; a larger rank means a more severe level.
        fn rank(level: &SafetySeverity) -> u8 {
            match level {
                SafetySeverity::Info => 0,
                SafetySeverity::Low => 1,
                SafetySeverity::Medium => 2,
                SafetySeverity::High => 3,
                SafetySeverity::Critical => 4,
            }
        }

        rank(self).cmp(&rank(other))
    }
}
145
/// Summary statistics for unsafe code audit.
///
/// Built from the per-file findings by `UnsafeAuditor::generate_summary`.
#[derive(Debug, Clone)]
pub struct UnsafeSummary {
    /// Breakdown by unsafe operation type (number of findings per type)
    pub types_breakdown: HashMap<UnsafeType, usize>,
    /// Breakdown by severity (number of findings per severity)
    pub severity_breakdown: HashMap<SafetySeverity, usize>,
    /// Files with the most unsafe code: at most 10 `(path, finding count)` pairs,
    /// highest count first
    pub top_unsafe_files: Vec<(PathBuf, usize)>,
    /// Common unsafe patterns found (human-readable labels, each listed once)
    pub common_patterns: Vec<String>,
}
158
/// Safety improvement recommendation.
///
/// Emitted by `UnsafeAuditor::generate_recommendations`, one per category that has
/// at least one affected file.
#[derive(Debug, Clone)]
pub struct SafetyRecommendation {
    /// Type of recommendation
    pub recommendation_type: RecommendationType,
    /// Description of the recommendation
    pub description: String,
    /// Files that would benefit from this recommendation
    pub affected_files: Vec<PathBuf>,
    /// Estimated effort to implement
    pub effort: EffortLevel,
    /// Safety impact of implementing this recommendation
    pub safety_impact: SafetyImpact,
}
173
/// Type of safety recommendation.
// NOTE(review): the auditor code visible in this file only produces
// `ReplaceWithSafe`, `ImproveDocumentation` and `Refactor`; the remaining variants
// may be used elsewhere — confirm before removing any.
#[derive(Debug, Clone)]
pub enum RecommendationType {
    /// Replace unsafe code with safe alternatives
    ReplaceWithSafe,
    /// Add better documentation/justification
    ImproveDocumentation,
    /// Reduce scope of unsafe operations
    ReduceScope,
    /// Add safety assertions/checks
    AddSafetyChecks,
    /// Refactor to eliminate unsafe code
    Refactor,
    /// Use safer abstractions
    UseSaferAbstractions,
}
190
/// Effort level for implementing recommendations.
///
/// The derived ordering follows declaration order, so
/// `Minimal < Low < Medium < High`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EffortLevel {
    /// Minimal effort (< 1 hour)
    Minimal,
    /// Low effort (1-4 hours)
    Low,
    /// Medium effort (4-16 hours)
    Medium,
    /// High effort (16+ hours)
    High,
}
203
/// Safety impact of recommendations.
///
/// The derived ordering follows declaration order, so
/// `Critical < High < Medium < Low`.
// NOTE(review): this runs in the opposite direction to `SafetySeverity` and
// `EffortLevel`, where a larger value means "more". Sorting ascending puts the most
// impactful items first — confirm that this is intended before comparing values.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SafetyImpact {
    /// Critical safety improvement
    Critical,
    /// High safety improvement
    High,
    /// Medium safety improvement
    Medium,
    /// Low safety improvement
    Low,
}
216
impl Default for UnsafeAuditConfig {
    /// Default configuration: scan `src`; exclude `target`, `benches` and `examples`;
    /// allow at most 5 findings per file; require justifications; non-strict mode;
    /// and use the built-in allow-list from `default_safe_patterns`.
    fn default() -> Self {
        Self {
            scan_paths: vec![PathBuf::from("src")],
            exclude_paths: vec![
                PathBuf::from("target"),
                PathBuf::from("benches"),
                PathBuf::from("examples"),
            ],
            max_unsafe_per_file: 5,
            require_justification: true,
            strict_mode: false,
            allowed_patterns: Self::default_safe_patterns(),
        }
    }
}
233
234impl UnsafeAuditConfig {
235    /// Get default set of known safe patterns
236    fn default_safe_patterns() -> Vec<UnsafePattern> {
237        vec![
238            UnsafePattern {
239                name: "SIMD Operations".to_string(),
240                signatures: vec!["std::simd::".to_string(), "std::arch::".to_string()],
241                justification: "SIMD operations are generally safe when used correctly".to_string(),
242                preconditions: vec![
243                    "Input arrays are properly aligned".to_string(),
244                    "Array bounds are checked".to_string(),
245                ],
246            },
247            UnsafePattern {
248                name: "Slice from Raw Parts".to_string(),
249                signatures: vec![
250                    "std::slice::from_raw_parts".to_string(),
251                    "std::slice::from_raw_parts_mut".to_string(),
252                ],
253                justification: "Safe when pointer and length are valid".to_string(),
254                preconditions: vec![
255                    "Pointer is non-null and properly aligned".to_string(),
256                    "Length is accurate and doesn't overflow".to_string(),
257                    "Memory is valid for the lifetime".to_string(),
258                ],
259            },
260            UnsafePattern {
261                name: "FFI Bindings".to_string(),
262                signatures: vec!["extern".to_string()],
263                justification: "FFI calls to well-tested C libraries".to_string(),
264                preconditions: vec![
265                    "C library is memory-safe".to_string(),
266                    "Parameters are validated".to_string(),
267                    "Return values are checked".to_string(),
268                ],
269            },
270        ]
271    }
272}
273
/// Main unsafe code auditor.
///
/// Walks the configured scan paths, inspects `.rs` files with textual heuristics,
/// and produces an [`UnsafeAuditReport`].
pub struct UnsafeAuditor {
    /// Settings controlling what is scanned and when the audit passes
    config: UnsafeAuditConfig,
}
278
279impl UnsafeAuditor {
280    /// Create a new auditor with default configuration
281    pub fn new() -> Self {
282        Self {
283            config: UnsafeAuditConfig::default(),
284        }
285    }
286
    /// Create a new auditor with custom configuration.
    ///
    /// The configuration is stored as-is; no validation is performed here.
    pub fn with_config(config: UnsafeAuditConfig) -> Self {
        Self { config }
    }
291
292    /// Run complete unsafe code audit
293    pub fn audit<P: AsRef<Path>>(&self, root_path: P) -> Result<UnsafeAuditReport> {
294        let root_path = root_path.as_ref();
295        let mut findings = HashMap::new();
296        let mut files_scanned = 0;
297        let mut total_unsafe_blocks = 0;
298
299        // Scan all Rust files in the specified paths
300        for scan_path in &self.config.scan_paths {
301            let full_path = root_path.join(scan_path);
302            if full_path.exists() {
303                self.scan_directory(
304                    &full_path,
305                    &mut findings,
306                    &mut files_scanned,
307                    &mut total_unsafe_blocks,
308                )?;
309            }
310        }
311
312        let files_with_unsafe = findings.len();
313        let passed = self.evaluate_audit_results(&findings);
314        let summary = self.generate_summary(&findings);
315        let recommendations = self.generate_recommendations(&findings);
316
317        Ok(UnsafeAuditReport {
318            passed,
319            files_scanned,
320            files_with_unsafe,
321            total_unsafe_blocks,
322            findings,
323            summary,
324            recommendations,
325        })
326    }
327
328    /// Scan a directory for unsafe code
329    fn scan_directory(
330        &self,
331        dir: &Path,
332        findings: &mut HashMap<PathBuf, Vec<UnsafeFinding>>,
333        files_scanned: &mut usize,
334        total_unsafe: &mut usize,
335    ) -> Result<()> {
336        if self.should_exclude(dir) {
337            return Ok(());
338        }
339
340        let entries = fs::read_dir(dir)
341            .map_err(|e| SklearsError::InvalidInput(format!("Failed to read directory: {e}")))?;
342
343        for entry in entries {
344            let entry = entry
345                .map_err(|e| SklearsError::InvalidInput(format!("Failed to read entry: {e}")))?;
346            let path = entry.path();
347
348            if path.is_dir() {
349                self.scan_directory(&path, findings, files_scanned, total_unsafe)?;
350            } else if path.extension().map(|ext| ext == "rs").unwrap_or(false)
351                && !self.should_exclude(&path)
352            {
353                *files_scanned += 1;
354                let file_findings = self.scan_file(&path)?;
355                *total_unsafe += file_findings.len();
356                if !file_findings.is_empty() {
357                    findings.insert(path, file_findings);
358                }
359            }
360        }
361
362        Ok(())
363    }
364
365    /// Scan a single file for unsafe code
366    fn scan_file(&self, file_path: &Path) -> Result<Vec<UnsafeFinding>> {
367        let content = fs::read_to_string(file_path)
368            .map_err(|e| SklearsError::InvalidInput(format!("Failed to read file: {e}")))?;
369
370        let mut findings = Vec::new();
371        let lines: Vec<&str> = content.lines().collect();
372
373        for (line_num, line) in lines.iter().enumerate() {
374            if let Some(finding) = self.analyze_line(file_path, line_num + 1, line) {
375                findings.push(finding);
376            }
377        }
378
379        // Check for block-level unsafe patterns
380        findings.extend(self.analyze_unsafe_blocks(file_path, &content)?);
381
382        Ok(findings)
383    }
384
385    /// Analyze a single line for unsafe patterns
386    fn analyze_line(&self, file_path: &Path, line_num: usize, line: &str) -> Option<UnsafeFinding> {
387        let trimmed = line.trim();
388
389        // Check for unsafe keyword
390        if trimmed.starts_with("unsafe") {
391            let unsafe_type = self.determine_unsafe_type(line);
392            let severity = self.assess_severity(&unsafe_type, line);
393            let is_known_safe = self.is_known_safe_pattern(line);
394            let justification = self.extract_justification(line);
395            let suggestions = self.generate_suggestions(&unsafe_type, line);
396
397            Some(UnsafeFinding {
398                file: file_path.to_path_buf(),
399                line: line_num,
400                column: line.find("unsafe"),
401                unsafe_type,
402                code_snippet: line.to_string(),
403                justification,
404                is_known_safe,
405                severity,
406                suggestions,
407            })
408        } else {
409            None
410        }
411    }
412
413    /// Analyze unsafe blocks in the entire file
414    fn analyze_unsafe_blocks(&self, file_path: &Path, content: &str) -> Result<Vec<UnsafeFinding>> {
415        let mut findings = Vec::new();
416        let mut in_unsafe_block = false;
417        let mut block_start = 0;
418        let mut brace_count = 0;
419
420        for (line_num, line) in content.lines().enumerate() {
421            if line.contains("unsafe {") {
422                in_unsafe_block = true;
423                block_start = line_num + 1;
424                brace_count = 1;
425            } else if in_unsafe_block {
426                brace_count += line.matches('{').count();
427                brace_count -= line.matches('}').count();
428
429                if brace_count == 0 {
430                    // End of unsafe block
431                    in_unsafe_block = false;
432
433                    // Extract the entire unsafe block
434                    let block_lines: Vec<&str> = content
435                        .lines()
436                        .skip(block_start - 1)
437                        .take(line_num - block_start + 2)
438                        .collect();
439                    let block_content = block_lines.join("\n");
440
441                    let unsafe_type = UnsafeType::UnsafeBlock;
442                    let severity = self.assess_block_severity(&block_content);
443                    let is_known_safe = self.is_known_safe_pattern(&block_content);
444                    let justification = self.extract_block_justification(&block_content);
445                    let suggestions = self.generate_block_suggestions(&block_content);
446
447                    findings.push(UnsafeFinding {
448                        file: file_path.to_path_buf(),
449                        line: block_start,
450                        column: None,
451                        unsafe_type,
452                        code_snippet: block_content,
453                        justification,
454                        is_known_safe,
455                        severity,
456                        suggestions,
457                    });
458                }
459            }
460        }
461
462        Ok(findings)
463    }
464
465    /// Determine the type of unsafe operation
466    fn determine_unsafe_type(&self, line: &str) -> UnsafeType {
467        if line.contains("transmute") {
468            UnsafeType::Transmute
469        } else if line.contains("asm!") {
470            UnsafeType::InlineAssembly
471        } else if line.contains("static mut") {
472            UnsafeType::MutableStatic
473        } else if line.contains("union") {
474            UnsafeType::UnionFieldAccess
475        } else if line.contains("*ptr")
476            || (line.contains("*") && (line.contains("as *") || line.contains("->")))
477        {
478            UnsafeType::RawPointerDeref
479        } else if line.contains("func()")
480            || (line.contains("(")
481                && line.contains(")")
482                && !line.contains("asm!")
483                && !line.contains("transmute"))
484        {
485            UnsafeType::UnsafeFunctionCall
486        } else {
487            UnsafeType::UnsafeBlock
488        }
489    }
490
491    /// Assess the severity of an unsafe operation
492    fn assess_severity(&self, unsafe_type: &UnsafeType, code: &str) -> SafetySeverity {
493        match unsafe_type {
494            UnsafeType::Transmute => SafetySeverity::Critical,
495            UnsafeType::InlineAssembly => SafetySeverity::Critical,
496            UnsafeType::MutableStatic => SafetySeverity::High,
497            UnsafeType::RawPointerDeref => {
498                if code.contains("null") || code.contains("dangling") {
499                    SafetySeverity::Critical
500                } else {
501                    SafetySeverity::High
502                }
503            }
504            UnsafeType::UnsafeFunctionCall => {
505                if self.is_known_safe_pattern(code) {
506                    SafetySeverity::Low
507                } else {
508                    SafetySeverity::Medium
509                }
510            }
511            UnsafeType::UnionFieldAccess => SafetySeverity::Medium,
512            UnsafeType::UnsafeBlock => SafetySeverity::Medium,
513        }
514    }
515
516    /// Assess the severity of an entire unsafe block
517    fn assess_block_severity(&self, block_content: &str) -> SafetySeverity {
518        let critical_patterns = ["transmute", "asm!", "null"];
519        let high_patterns = ["static mut", "*mut", "*const"];
520
521        for pattern in &critical_patterns {
522            if block_content.contains(pattern) {
523                return SafetySeverity::Critical;
524            }
525        }
526
527        for pattern in &high_patterns {
528            if block_content.contains(pattern) {
529                return SafetySeverity::High;
530            }
531        }
532
533        SafetySeverity::Medium
534    }
535
536    /// Check if a pattern is known to be safe
537    fn is_known_safe_pattern(&self, code: &str) -> bool {
538        for pattern in &self.config.allowed_patterns {
539            for signature in &pattern.signatures {
540                if code.contains(signature) {
541                    return true;
542                }
543            }
544        }
545        false
546    }
547
548    /// Extract justification from comments
549    fn extract_justification(&self, line: &str) -> Option<String> {
550        if let Some(comment_start) = line.find("//") {
551            let comment = &line[comment_start + 2..].trim();
552            if !comment.is_empty() {
553                Some(comment.to_string())
554            } else {
555                None
556            }
557        } else {
558            None
559        }
560    }
561
562    /// Extract justification from unsafe block comments
563    fn extract_block_justification(&self, block: &str) -> Option<String> {
564        let lines: Vec<&str> = block.lines().collect();
565        for line in lines {
566            if let Some(comment_start) = line.find("//") {
567                let comment = &line[comment_start + 2..].trim();
568                if comment.to_lowercase().contains("safety")
569                    || comment.to_lowercase().contains("justification")
570                    || comment.to_lowercase().contains("safe because")
571                {
572                    return Some(comment.to_string());
573                }
574            }
575        }
576        None
577    }
578
579    /// Generate suggestions for improving unsafe code
580    fn generate_suggestions(&self, unsafe_type: &UnsafeType, _code: &str) -> Vec<String> {
581        match unsafe_type {
582            UnsafeType::RawPointerDeref => vec![
583                "Consider using safe array indexing with bounds checking".to_string(),
584                "Use slice methods instead of raw pointer arithmetic".to_string(),
585                "Add explicit null pointer checks".to_string(),
586            ],
587            UnsafeType::UnsafeFunctionCall => vec![
588                "Document why this function call is safe".to_string(),
589                "Consider wrapping in a safe abstraction".to_string(),
590                "Validate all parameters before calling".to_string(),
591            ],
592            UnsafeType::Transmute => vec![
593                "Use safe type conversion methods instead".to_string(),
594                "Consider using union types for type punning".to_string(),
595                "Add size and alignment assertions".to_string(),
596            ],
597            UnsafeType::MutableStatic => vec![
598                "Use thread-local storage or synchronization".to_string(),
599                "Consider using lazy_static or once_cell".to_string(),
600                "Document thread safety guarantees".to_string(),
601            ],
602            UnsafeType::InlineAssembly => vec![
603                "Document assembly code thoroughly".to_string(),
604                "Consider using intrinsics instead".to_string(),
605                "Add extensive testing for different platforms".to_string(),
606            ],
607            UnsafeType::UnionFieldAccess => vec![
608                "Document which field is active".to_string(),
609                "Use tagged unions for safety".to_string(),
610                "Consider using enums instead".to_string(),
611            ],
612            UnsafeType::UnsafeBlock => vec![
613                "Minimize the scope of the unsafe block".to_string(),
614                "Document all safety invariants".to_string(),
615                "Add safety assertions where possible".to_string(),
616            ],
617        }
618    }
619
620    /// Generate suggestions for improving unsafe blocks
621    fn generate_block_suggestions(&self, block: &str) -> Vec<String> {
622        let mut suggestions = Vec::new();
623
624        if !block.contains("//") {
625            suggestions.push("Add comments explaining why this unsafe code is safe".to_string());
626        }
627
628        if block.lines().count() > 10 {
629            suggestions
630                .push("Consider breaking this large unsafe block into smaller pieces".to_string());
631        }
632
633        if block.contains("panic!") {
634            suggestions.push("Avoid panicking inside unsafe blocks".to_string());
635        }
636
637        suggestions.push("Add debug assertions to validate safety invariants".to_string());
638        suggestions.push("Consider creating a safe wrapper function".to_string());
639
640        suggestions
641    }
642
643    /// Check if a path should be excluded from auditing
644    fn should_exclude(&self, path: &Path) -> bool {
645        for exclude_path in &self.config.exclude_paths {
646            if path.ends_with(exclude_path)
647                || path
648                    .components()
649                    .any(|c| c.as_os_str() == exclude_path.as_os_str())
650            {
651                return true;
652            }
653        }
654        false
655    }
656
657    /// Evaluate whether the audit results pass the configured criteria
658    fn evaluate_audit_results(&self, findings: &HashMap<PathBuf, Vec<UnsafeFinding>>) -> bool {
659        if self.config.strict_mode {
660            return findings.is_empty();
661        }
662
663        // Check per-file limits
664        for file_findings in findings.values() {
665            if file_findings.len() > self.config.max_unsafe_per_file {
666                return false;
667            }
668
669            // If justification is required, check that critical findings have justification
670            if self.config.require_justification {
671                for finding in file_findings {
672                    if finding.severity >= SafetySeverity::High && finding.justification.is_none() {
673                        return false;
674                    }
675                }
676            }
677        }
678
679        true
680    }
681
682    /// Generate summary statistics
683    fn generate_summary(&self, findings: &HashMap<PathBuf, Vec<UnsafeFinding>>) -> UnsafeSummary {
684        let mut types_breakdown = HashMap::new();
685        let mut severity_breakdown = HashMap::new();
686        let mut file_counts = Vec::new();
687        let mut patterns = HashSet::new();
688
689        for (file, file_findings) in findings {
690            file_counts.push((file.clone(), file_findings.len()));
691
692            for finding in file_findings {
693                *types_breakdown
694                    .entry(finding.unsafe_type.clone())
695                    .or_insert(0) += 1;
696                *severity_breakdown
697                    .entry(finding.severity.clone())
698                    .or_insert(0) += 1;
699
700                // Extract common patterns
701                if finding.code_snippet.contains("transmute") {
702                    patterns.insert("transmute usage".to_string());
703                }
704                if finding.code_snippet.contains("*mut") || finding.code_snippet.contains("*const")
705                {
706                    patterns.insert("raw pointer usage".to_string());
707                }
708                if finding.code_snippet.contains("std::slice::from_raw_parts") {
709                    patterns.insert("slice from raw parts".to_string());
710                }
711            }
712        }
713
714        // Sort files by unsafe count
715        file_counts.sort_by(|a, b| b.1.cmp(&a.1));
716        let top_unsafe_files = file_counts.into_iter().take(10).collect();
717
718        UnsafeSummary {
719            types_breakdown,
720            severity_breakdown,
721            top_unsafe_files,
722            common_patterns: patterns.into_iter().collect(),
723        }
724    }
725
726    /// Generate recommendations for improving code safety
727    fn generate_recommendations(
728        &self,
729        findings: &HashMap<PathBuf, Vec<UnsafeFinding>>,
730    ) -> Vec<SafetyRecommendation> {
731        let mut recommendations = Vec::new();
732
733        // Analyze patterns and generate recommendations
734        let mut files_with_high_severity = Vec::new();
735        let mut files_without_justification = Vec::new();
736        let mut files_with_many_unsafe = Vec::new();
737
738        for (file, file_findings) in findings {
739            let high_severity_count = file_findings
740                .iter()
741                .filter(|f| f.severity >= SafetySeverity::High)
742                .count();
743
744            let missing_justification_count = file_findings
745                .iter()
746                .filter(|f| f.severity >= SafetySeverity::Medium && f.justification.is_none())
747                .count();
748
749            if high_severity_count > 0 {
750                files_with_high_severity.push(file.clone());
751            }
752
753            if missing_justification_count > 0 {
754                files_without_justification.push(file.clone());
755            }
756
757            if file_findings.len() > self.config.max_unsafe_per_file {
758                files_with_many_unsafe.push(file.clone());
759            }
760        }
761
762        // Generate specific recommendations
763        if !files_with_high_severity.is_empty() {
764            recommendations.push(SafetyRecommendation {
765                recommendation_type: RecommendationType::ReplaceWithSafe,
766                description: "Replace high-severity unsafe code with safe alternatives".to_string(),
767                affected_files: files_with_high_severity,
768                effort: EffortLevel::High,
769                safety_impact: SafetyImpact::Critical,
770            });
771        }
772
773        if !files_without_justification.is_empty() {
774            recommendations.push(SafetyRecommendation {
775                recommendation_type: RecommendationType::ImproveDocumentation,
776                description: "Add safety justifications for all unsafe code".to_string(),
777                affected_files: files_without_justification,
778                effort: EffortLevel::Low,
779                safety_impact: SafetyImpact::Medium,
780            });
781        }
782
783        if !files_with_many_unsafe.is_empty() {
784            recommendations.push(SafetyRecommendation {
785                recommendation_type: RecommendationType::Refactor,
786                description: "Refactor files with excessive unsafe code".to_string(),
787                affected_files: files_with_many_unsafe,
788                effort: EffortLevel::High,
789                safety_impact: SafetyImpact::High,
790            });
791        }
792
793        recommendations
794    }
795
    /// Get the current configuration.
    ///
    /// Returns a shared borrow; use `set_config` to replace it.
    pub fn config(&self) -> &UnsafeAuditConfig {
        &self.config
    }
800
    /// Update the configuration.
    ///
    /// Replaces the whole configuration; subsequent audits use the new value.
    pub fn set_config(&mut self, config: UnsafeAuditConfig) {
        self.config = config;
    }
805}
806
impl Default for UnsafeAuditor {
    /// Equivalent to [`UnsafeAuditor::new`]: an auditor with the default configuration.
    fn default() -> Self {
        Self::new()
    }
}
812
813#[allow(non_snake_case)]
814#[cfg(test)]
815mod tests {
816    use super::*;
817
    /// The default configuration allows 5 findings per file, requires
    /// justifications, is non-strict, and ships a non-empty allow-list.
    #[test]
    fn test_unsafe_audit_config_default() {
        let config = UnsafeAuditConfig::default();
        assert_eq!(config.max_unsafe_per_file, 5);
        assert!(config.require_justification);
        assert!(!config.strict_mode);
        assert!(!config.allowed_patterns.is_empty());
    }
826
    /// `UnsafeAuditor::new` installs the default configuration.
    #[test]
    fn test_unsafe_auditor_creation() {
        let auditor = UnsafeAuditor::new();
        assert_eq!(auditor.config().max_unsafe_per_file, 5);
    }
832
    /// Each representative one-line snippet is classified as the expected
    /// `UnsafeType` by the substring heuristics.
    #[test]
    fn test_determine_unsafe_type() {
        let auditor = UnsafeAuditor::new();

        assert_eq!(
            auditor.determine_unsafe_type("unsafe { *ptr }"),
            UnsafeType::RawPointerDeref
        );
        assert_eq!(
            auditor.determine_unsafe_type("unsafe { transmute(x) }"),
            UnsafeType::Transmute
        );
        assert_eq!(
            auditor.determine_unsafe_type("unsafe { static mut X }"),
            UnsafeType::MutableStatic
        );
        assert_eq!(
            auditor.determine_unsafe_type("unsafe { asm!() }"),
            UnsafeType::InlineAssembly
        );
        assert_eq!(
            auditor.determine_unsafe_type("unsafe { func() }"),
            UnsafeType::UnsafeFunctionCall
        );
    }
858
859    #[test]
860    fn test_assess_severity() {
861        let auditor = UnsafeAuditor::new();
862
863        assert_eq!(
864            auditor.assess_severity(&UnsafeType::Transmute, "transmute"),
865            SafetySeverity::Critical
866        );
867        assert_eq!(
868            auditor.assess_severity(&UnsafeType::InlineAssembly, "asm!"),
869            SafetySeverity::Critical
870        );
871        assert_eq!(
872            auditor.assess_severity(&UnsafeType::MutableStatic, "static mut"),
873            SafetySeverity::High
874        );
875        assert_eq!(
876            auditor.assess_severity(&UnsafeType::RawPointerDeref, "*null"),
877            SafetySeverity::Critical
878        );
879        assert_eq!(
880            auditor.assess_severity(&UnsafeType::RawPointerDeref, "*ptr"),
881            SafetySeverity::High
882        );
883    }
884
885    #[test]
886    fn test_is_known_safe_pattern() {
887        let auditor = UnsafeAuditor::new();
888
889        assert!(auditor.is_known_safe_pattern("std::simd::f32x4::new()"));
890        assert!(auditor.is_known_safe_pattern("std::slice::from_raw_parts(ptr, len)"));
891        assert!(!auditor.is_known_safe_pattern("transmute(x)"));
892    }
893
894    #[test]
895    fn test_extract_justification() {
896        let auditor = UnsafeAuditor::new();
897
898        let result =
899            auditor.extract_justification("unsafe { *ptr } // SAFETY: ptr is guaranteed non-null");
900        assert_eq!(
901            result,
902            Some("SAFETY: ptr is guaranteed non-null".to_string())
903        );
904
905        let result = auditor.extract_justification("unsafe { *ptr }");
906        assert_eq!(result, None);
907    }
908
909    #[test]
910    fn test_generate_suggestions() {
911        let auditor = UnsafeAuditor::new();
912
913        let suggestions = auditor.generate_suggestions(&UnsafeType::RawPointerDeref, "*ptr");
914        assert!(!suggestions.is_empty());
915        assert!(suggestions.iter().any(|s| s.contains("bounds checking")));
916
917        let suggestions = auditor.generate_suggestions(&UnsafeType::Transmute, "transmute");
918        assert!(suggestions
919            .iter()
920            .any(|s| s.contains("safe type conversion")));
921    }
922
923    #[test]
924    fn test_should_exclude() {
925        let config = UnsafeAuditConfig {
926            exclude_paths: vec![PathBuf::from("target"), PathBuf::from("benches")],
927            ..Default::default()
928        };
929        let auditor = UnsafeAuditor::with_config(config);
930
931        assert!(auditor.should_exclude(Path::new("target/debug/foo")));
932        assert!(auditor.should_exclude(Path::new("benches/benchmark.rs")));
933        assert!(!auditor.should_exclude(Path::new("src/lib.rs")));
934    }
935
936    #[test]
937    fn test_unsafe_finding_creation() {
938        let finding = UnsafeFinding {
939            file: PathBuf::from("test.rs"),
940            line: 10,
941            column: Some(5),
942            unsafe_type: UnsafeType::RawPointerDeref,
943            code_snippet: "unsafe { *ptr }".to_string(),
944            justification: Some("ptr is non-null".to_string()),
945            is_known_safe: false,
946            severity: SafetySeverity::High,
947            suggestions: vec!["Use safe indexing".to_string()],
948        };
949
950        assert_eq!(finding.file, PathBuf::from("test.rs"));
951        assert_eq!(finding.line, 10);
952        assert_eq!(finding.unsafe_type, UnsafeType::RawPointerDeref);
953        assert_eq!(finding.severity, SafetySeverity::High);
954    }
955
956    #[test]
957    fn test_safety_severity_ordering() {
958        assert!(SafetySeverity::Critical > SafetySeverity::High);
959        assert!(SafetySeverity::High > SafetySeverity::Medium);
960        assert!(SafetySeverity::Medium > SafetySeverity::Low);
961        assert!(SafetySeverity::Low > SafetySeverity::Info);
962    }
963
964    #[test]
965    fn test_effort_level_ordering() {
966        assert!(EffortLevel::High > EffortLevel::Medium);
967        assert!(EffortLevel::Medium > EffortLevel::Low);
968        assert!(EffortLevel::Low > EffortLevel::Minimal);
969    }
970}