Skip to main content

rustinel_core/
risk.rs

1use crate::lockfile::LockfileModel;
2use crate::signals::{RiskSignal, Severity};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ProjectRisk {
8    pub score: u8,
9    pub level: RiskLevel,
10    /// Highest single-package score, useful for `max_package_score` policy checks.
11    pub max_package_score: u8,
12    /// Per-package scores, sorted by package id for determinism.
13    pub packages: Vec<PackageRisk>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct PackageRisk {
18    pub package: String,
19    pub score: u8,
20    pub level: RiskLevel,
21}
22
23#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
24#[serde(rename_all = "snake_case")]
25pub enum RiskLevel {
26    Low,
27    Medium,
28    High,
29    Critical,
30}
31
32impl RiskLevel {
33    pub fn as_str(&self) -> &'static str {
34        match self {
35            RiskLevel::Low => "low",
36            RiskLevel::Medium => "medium",
37            RiskLevel::High => "high",
38            RiskLevel::Critical => "critical",
39        }
40    }
41}
42
43/// Decay factor for repeated heuristic findings of the same class.
44const HEURISTIC_DECAY: f64 = 0.5;
45
46fn is_advisory(signal: &RiskSignal) -> bool {
47    signal.id.starts_with("advisory_")
48}
49
50/// Compute the project risk score from collected signals (0–100).
51///
52/// The aggregation is deliberately *not* a flat sum, to avoid false-positive
53/// inflation on large dependency trees:
54///
55/// - **Advisory findings** (real, matched vulnerabilities) are summed in full —
56///   each additional known vulnerability genuinely adds risk. A `Critical`
57///   advisory pins the project score to 100.
58/// - **Heuristic findings** (FFI, build scripts, `unsafe`, duplicate versions,
59///   license) get *diminishing returns per class*: the largest finding of a
60///   class counts in full, the next at 50%, then 25%, … so a project with 30
61///   `-sys` crates is not scored as 30× the risk of one.
62///
63/// Per-package scores use a plain saturating sum so that a single highly-risky
64/// package can still trip `max_package_score`.
65pub fn score_project(_lock: &LockfileModel, signals: &[RiskSignal]) -> ProjectRisk {
66    let mut critical = false;
67    let mut advisory_sum: f64 = 0.0;
68    // Heuristic weights grouped by signal id, for diminishing aggregation.
69    let mut heuristic_by_id: BTreeMap<&str, Vec<u8>> = BTreeMap::new();
70    let mut per_package: BTreeMap<&str, u16> = BTreeMap::new();
71    let mut package_critical: BTreeMap<&str, bool> = BTreeMap::new();
72
73    for signal in signals {
74        let pkg = per_package.entry(&signal.package).or_insert(0);
75        *pkg = pkg.saturating_add(signal.weight as u16);
76        if signal.severity == Severity::Critical {
77            critical = true;
78            package_critical.insert(&signal.package, true);
79        }
80        if is_advisory(signal) {
81            advisory_sum += signal.weight as f64;
82        } else if signal.weight > 0 {
83            heuristic_by_id
84                .entry(&signal.id)
85                .or_default()
86                .push(signal.weight);
87        }
88    }
89
90    let mut heuristic_sum = 0.0;
91    for weights in heuristic_by_id.values_mut() {
92        weights.sort_unstable_by(|a, b| b.cmp(a)); // largest first
93        for (i, w) in weights.iter().enumerate() {
94            heuristic_sum += (*w as f64) * HEURISTIC_DECAY.powi(i as i32);
95        }
96    }
97
98    let raw = advisory_sum + heuristic_sum;
99    let score = if critical {
100        100
101    } else {
102        raw.round().min(100.0) as u8
103    };
104
105    let packages = per_package
106        .into_iter()
107        .map(|(name, raw)| {
108            let s = if package_critical.get(name).copied().unwrap_or(false) {
109                100
110            } else {
111                raw.min(100) as u8
112            };
113            PackageRisk {
114                package: name.to_string(),
115                score: s,
116                level: level_for_score(s),
117            }
118        })
119        .collect::<Vec<_>>();
120
121    let max_package_score = packages.iter().map(|p| p.score).max().unwrap_or(0);
122
123    ProjectRisk {
124        score,
125        level: level_for_score(score),
126        max_package_score,
127        packages,
128    }
129}
130
/// A transparent breakdown of how a project score was computed (see [`explain`]).
#[derive(Debug, Clone)]
pub struct ScoreExplanation {
    /// `(label, points)` contributions, largest first (ties ordered by label).
    pub contributions: Vec<(String, f64)>,
    /// Final 0–100 score: the rounded sum of `contributions` capped at 100,
    /// or exactly 100 when `critical_pin` is set.
    pub total: u8,
    /// True when any signal with critical severity (advisory or heuristic)
    /// pinned the score to 100.
    pub critical_pin: bool,
}
140
141/// Explain the score: per-advisory full weights plus per-class diminishing
142/// heuristic contributions. The summed `total` is guaranteed to equal
143/// [`score_project`]'s score for the same signals (asserted in tests).
144pub fn explain(signals: &[RiskSignal]) -> ScoreExplanation {
145    let mut critical = false;
146    let mut contributions: Vec<(String, f64)> = Vec::new();
147    let mut heuristic_by_id: BTreeMap<&str, Vec<u8>> = BTreeMap::new();
148
149    for signal in signals {
150        if signal.severity == Severity::Critical {
151            critical = true;
152        }
153        if is_advisory(signal) {
154            contributions.push((
155                format!(
156                    "{} on {}",
157                    signal.id.trim_start_matches("advisory_"),
158                    signal.package
159                ),
160                signal.weight as f64,
161            ));
162        } else if signal.weight > 0 {
163            heuristic_by_id
164                .entry(&signal.id)
165                .or_default()
166                .push(signal.weight);
167        }
168    }
169
170    for (id, weights) in heuristic_by_id.iter_mut() {
171        weights.sort_unstable_by(|a, b| b.cmp(a));
172        let sum: f64 = weights
173            .iter()
174            .enumerate()
175            .map(|(i, w)| (*w as f64) * HEURISTIC_DECAY.powi(i as i32))
176            .sum();
177        contributions.push((format!("{} (×{})", id, weights.len()), sum));
178    }
179
180    let raw: f64 = contributions.iter().map(|(_, p)| p).sum();
181    let total = if critical {
182        100
183    } else {
184        raw.round().min(100.0) as u8
185    };
186
187    contributions.sort_by(|a, b| {
188        b.1.partial_cmp(&a.1)
189            .unwrap_or(std::cmp::Ordering::Equal)
190            .then_with(|| a.0.cmp(&b.0))
191    });
192
193    ScoreExplanation {
194        contributions,
195        total,
196        critical_pin: critical,
197    }
198}
199
200pub fn level_for_score(score: u8) -> RiskLevel {
201    match score {
202        0..=19 => RiskLevel::Low,
203        20..=49 => RiskLevel::Medium,
204        50..=79 => RiskLevel::High,
205        _ => RiskLevel::Critical,
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signals::RiskSignal;

    /// Signal with the generic heuristic id `"x"`.
    fn sig(pkg: &str, sev: Severity, weight: u8) -> RiskSignal {
        sig_id("x", pkg, sev, weight)
    }

    fn sig_id(id: &str, pkg: &str, sev: Severity, weight: u8) -> RiskSignal {
        RiskSignal {
            id: id.into(),
            package: pkg.into(),
            severity: sev,
            weight,
            confidence: 1.0,
            evidence: vec![],
            recommendation: String::new(),
        }
    }

    fn empty_lock() -> LockfileModel {
        LockfileModel {
            path: "Cargo.lock".into(),
            version: None,
            packages: vec![],
        }
    }

    #[test]
    fn empty_is_low() {
        let r = score_project(&empty_lock(), &[]);
        assert_eq!(r.score, 0);
        assert_eq!(r.level, RiskLevel::Low);
    }

    #[test]
    fn diminishing_within_same_class() {
        // Two findings of the SAME id => diminishing: 20 + 12*0.5 = 26.
        let signals = vec![
            sig("a@1", Severity::High, 20),
            sig("a@1", Severity::Medium, 12),
        ];
        let r = score_project(&empty_lock(), &signals);
        assert_eq!(r.score, 26);
        // Per-package score still sums fully so a single bad crate can trip policy.
        assert_eq!(r.max_package_score, 32);
    }

    #[test]
    fn different_classes_sum_fully() {
        let signals = vec![
            sig_id("native_ffi_detected", "a@1", Severity::Medium, 14),
            sig_id("build_script_present", "a@1", Severity::Medium, 10),
        ];
        let r = score_project(&empty_lock(), &signals);
        assert_eq!(r.score, 24);
    }

    #[test]
    fn many_low_findings_do_not_saturate() {
        // 30 identical Low/8 FFI findings must NOT blow up to 100.
        let signals: Vec<RiskSignal> = (0..30)
            .map(|i| sig_id("native_ffi_detected", &format!("c{i}@1"), Severity::Low, 8))
            .collect();
        let r = score_project(&empty_lock(), &signals);
        // 8 * (1 + 0.5 + 0.25 + ...) -> converges to 16.
        assert!(r.score <= 16, "score was {}", r.score);
    }

    #[test]
    fn explain_total_matches_score() {
        let cases: Vec<Vec<RiskSignal>> = vec![
            vec![],
            vec![sig_id("native_ffi_detected", "a@1", Severity::Low, 8)],
            vec![
                sig_id("native_ffi_detected", "a@1", Severity::Low, 8),
                sig_id("native_ffi_detected", "b@1", Severity::Low, 8),
                sig_id("build_script_present", "a@1", Severity::Medium, 10),
                sig_id("advisory_RUSTSEC-1", "v@1", Severity::High, 30),
            ],
            vec![sig_id("advisory_RUSTSEC-X", "c@1", Severity::Critical, 60)],
            // A critical *heuristic* signal pins both functions as well.
            vec![sig("a@1", Severity::Critical, 5)],
            // Long decay tail: many findings of one class plus an advisory.
            (0..60)
                .map(|i| sig_id("native_ffi_detected", &format!("c{i}@1"), Severity::Low, 3))
                .chain(std::iter::once(sig_id(
                    "advisory_RUSTSEC-2",
                    "v@1",
                    Severity::Medium,
                    7,
                )))
                .collect(),
        ];
        for signals in cases {
            let score = score_project(&empty_lock(), &signals).score;
            let ex = explain(&signals);
            assert_eq!(ex.total, score, "explain/score drift for {signals:?}");
        }
    }

    #[test]
    fn explain_orders_largest_first_and_flags_pin() {
        let signals = vec![
            sig_id("native_ffi_detected", "a@1", Severity::Low, 8),
            sig_id("advisory_RUSTSEC-1", "v@1", Severity::High, 30),
        ];
        let ex = explain(&signals);
        assert!(!ex.critical_pin);
        assert_eq!(ex.total, 38);
        assert_eq!(ex.contributions.len(), 2);
        assert_eq!(ex.contributions[0], ("RUSTSEC-1 on v@1".to_string(), 30.0));
        assert_eq!(
            ex.contributions[1],
            ("native_ffi_detected (×1)".to_string(), 8.0)
        );

        let pinned = explain(&[sig("a@1", Severity::Critical, 5)]);
        assert!(pinned.critical_pin);
        assert_eq!(pinned.total, 100);
    }

    #[test]
    fn advisories_sum_fully() {
        let signals = vec![
            sig_id("advisory_RUSTSEC-1", "a@1", Severity::High, 30),
            sig_id("advisory_RUSTSEC-2", "b@1", Severity::High, 30),
        ];
        let r = score_project(&empty_lock(), &signals);
        assert_eq!(r.score, 60);
    }

    #[test]
    fn critical_pins_to_100() {
        let signals = vec![sig("a@1", Severity::Critical, 5)];
        let r = score_project(&empty_lock(), &signals);
        assert_eq!(r.score, 100);
        assert_eq!(r.level, RiskLevel::Critical);
    }

    #[test]
    fn per_package_scores_are_sorted_and_critical_pinned() {
        let signals = vec![
            sig_id("native_ffi_detected", "b@1", Severity::Low, 8),
            sig_id("advisory_RUSTSEC-9", "a@1", Severity::Critical, 40),
        ];
        let r = score_project(&empty_lock(), &signals);
        let names: Vec<&str> = r.packages.iter().map(|p| p.package.as_str()).collect();
        assert_eq!(names, ["a@1", "b@1"]);
        assert_eq!(r.packages[0].score, 100);
        assert_eq!(r.packages[0].level, RiskLevel::Critical);
        assert_eq!(r.packages[1].score, 8);
        assert_eq!(r.packages[1].level, RiskLevel::Low);
        assert_eq!(r.max_package_score, 100);
        assert_eq!(r.score, 100);
    }

    #[test]
    fn per_package_score_saturates_at_100() {
        // Same class twice: project = 60 + 60*0.5 = 90; package = min(120, 100).
        let signals = vec![
            sig("a@1", Severity::High, 60),
            sig("a@1", Severity::High, 60),
        ];
        let r = score_project(&empty_lock(), &signals);
        assert_eq!(r.score, 90);
        assert_eq!(r.max_package_score, 100);
    }

    #[test]
    fn level_boundaries() {
        assert_eq!(level_for_score(0), RiskLevel::Low);
        assert_eq!(level_for_score(19), RiskLevel::Low);
        assert_eq!(level_for_score(20), RiskLevel::Medium);
        assert_eq!(level_for_score(49), RiskLevel::Medium);
        assert_eq!(level_for_score(50), RiskLevel::High);
        assert_eq!(level_for_score(79), RiskLevel::High);
        assert_eq!(level_for_score(80), RiskLevel::Critical);
    }
}