Skip to main content

sbom_tools/verification/
model_dir.rs

1//! Model-weight integrity verification.
2//!
3//! Verifies the on-disk weight files of `MachineLearningModel` / `Data`
4//! components against the hashes recorded in an SBOM (typically injected by the
5//! HuggingFace enricher). For each such component this:
6//!
7//! 1. locates candidate weight files under a model directory, looking both for
8//!    direct filenames AND the HuggingFace cache snapshot layout where blob
9//!    files are named by their SHA-256 content hash, then
10//! 2. verifies the located file against the component's hash via the shared
11//!    [`verify_file_hash`](crate::verification::verify_file_hash).
12//!
13//! The result is a per-component pass / fail / missing report suitable for CI
14//! gating.
15
16use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18
19use serde::{Deserialize, Serialize};
20
21use crate::model::{Component, ComponentType, HashAlgorithm, NormalizedSbom};
22use crate::verification::verify_file_hash;
23
24/// Outcome of verifying a single model component's weights.
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub enum ModelVerifyResult {
27    /// A weight file was located and its hash matched the SBOM.
28    Verified,
29    /// A weight file was located but its hash did NOT match (possible tampering).
30    Mismatch,
31    /// The component declares hashes but no matching weight file was found.
32    Missing,
33    /// The component declares no usable (SHA-256/384/512) hash to verify against.
34    NoHash,
35}
36
37impl ModelVerifyResult {
38    /// Short status label.
39    #[must_use]
40    pub const fn label(&self) -> &'static str {
41        match self {
42            Self::Verified => "VERIFIED",
43            Self::Mismatch => "MISMATCH",
44            Self::Missing => "MISSING",
45            Self::NoHash => "NO-HASH",
46        }
47    }
48}
49
50/// Per-component model-weight verification record.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ComponentModelVerification {
53    /// Component name.
54    pub name: String,
55    /// Component version.
56    pub version: Option<String>,
57    /// Verification outcome.
58    pub result: ModelVerifyResult,
59    /// Hash value (hex) that was checked, when applicable.
60    pub hash: Option<String>,
61    /// Path of the weight file that was located, when applicable (relative to
62    /// the model directory for readability).
63    pub file: Option<String>,
64}
65
66/// Aggregate model-weight verification report.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ModelVerifyReport {
69    /// Model directory that was searched.
70    pub model_dir: String,
71    /// ML-model / dataset components inspected.
72    pub total_models: usize,
73    /// Components verified successfully.
74    pub verified_count: usize,
75    /// Components whose located weight file mismatched.
76    pub mismatch_count: usize,
77    /// Components with hashes but no located weight file.
78    pub missing_count: usize,
79    /// Components without a usable hash to verify.
80    pub no_hash_count: usize,
81    /// Per-component records.
82    pub components: Vec<ComponentModelVerification>,
83}
84
85impl ModelVerifyReport {
86    /// Whether the run had any failing component (mismatch or missing).
87    #[must_use]
88    pub const fn has_failures(&self) -> bool {
89        self.mismatch_count > 0 || self.missing_count > 0
90    }
91}
92
93/// Whether a hash algorithm is one we can verify a located file against.
94///
95/// We compute SHA-256 / SHA-512 over candidate files; SHA-384 shares SHA-512's
96/// preimage but a distinct digest, so only the two directly-computable forms are
97/// treated as verifiable here (matching `verify_file_hash`).
98const fn is_verifiable(alg: &HashAlgorithm) -> bool {
99    matches!(alg, HashAlgorithm::Sha256 | HashAlgorithm::Sha512)
100}
101
102/// Verify the weight files of all model/dataset components in `sbom` against the
103/// files found under `model_dir`.
104#[must_use]
105pub fn verify_model_dir(sbom: &NormalizedSbom, model_dir: &Path) -> ModelVerifyReport {
106    // Canonicalize the model-dir root once so symlink-escape detection (below)
107    // compares against a fully-resolved root. If the root itself can't be
108    // canonicalized (e.g. it does not exist), fall back to the path as given;
109    // the walk will simply find nothing.
110    let root = std::fs::canonicalize(model_dir).unwrap_or_else(|_| model_dir.to_path_buf());
111
112    // Index files by basename (for direct-filename matches) once, so a large
113    // model directory is walked a single time. Paths that resolve outside the
114    // root (via symlinks) are excluded by the index.
115    let index = FileIndex::build(&root);
116
117    let mut report = ModelVerifyReport {
118        model_dir: model_dir.display().to_string(),
119        total_models: 0,
120        verified_count: 0,
121        mismatch_count: 0,
122        missing_count: 0,
123        no_hash_count: 0,
124        components: Vec::new(),
125    };
126
127    for component in sbom.components.values() {
128        if !is_model_like(component) {
129            continue;
130        }
131        report.total_models += 1;
132
133        let record = verify_component(component, &root, &index);
134        match record.result {
135            ModelVerifyResult::Verified => report.verified_count += 1,
136            ModelVerifyResult::Mismatch => report.mismatch_count += 1,
137            ModelVerifyResult::Missing => report.missing_count += 1,
138            ModelVerifyResult::NoHash => report.no_hash_count += 1,
139        }
140        report.components.push(record);
141    }
142
143    report
144}
145
146/// Components whose weights we attempt to verify: trained models and datasets.
147fn is_model_like(component: &Component) -> bool {
148    matches!(
149        component.component_type,
150        ComponentType::MachineLearningModel | ComponentType::Data
151    )
152}
153
154/// Verify a single component, returning its record.
155fn verify_component(
156    component: &Component,
157    model_dir: &Path,
158    index: &FileIndex,
159) -> ComponentModelVerification {
160    let make = |result, hash: Option<String>, file: Option<String>| ComponentModelVerification {
161        name: component.name.clone(),
162        version: component.version.clone(),
163        result,
164        hash,
165        file,
166    };
167
168    // Only consider hashes we can recompute over a file.
169    let verifiable: Vec<_> = component
170        .hashes
171        .iter()
172        .filter(|h| is_verifiable(&h.algorithm))
173        .collect();
174
175    if verifiable.is_empty() {
176        return make(ModelVerifyResult::NoHash, None, None);
177    }
178
179    // Candidate filenames to look for, in addition to sha256-named blobs:
180    // any external-reference / model-card filename heuristics would be noisy, so
181    // we rely on (a) the hash-named blob (HF cache layout) and (b) the
182    // component name as a filename stem.
183    let name_candidates = filename_candidates(component);
184
185    let mut last_missing_hash: Option<String> = None;
186
187    for hash in verifiable {
188        let hash_hex = hash.value.to_lowercase();
189        last_missing_hash = Some(hash_hex.clone());
190
191        // 1. HuggingFace cache layout: a blob file is literally named by its
192        //    sha256. A direct hit means the bytes are present under that name.
193        if let Some(path) = index.by_basename(&hash_hex) {
194            return verify_against(component, &hash_hex, path, model_dir);
195        }
196
197        // 2. Direct filenames (e.g. `model.safetensors`, `<name>.safetensors`).
198        for candidate in &name_candidates {
199            if let Some(path) = index.by_basename(candidate) {
200                return verify_against(component, &hash_hex, path, model_dir);
201            }
202        }
203    }
204
205    make(ModelVerifyResult::Missing, last_missing_hash, None)
206}
207
208/// Run `verify_file_hash` for a located file and build the record.
209fn verify_against(
210    component: &Component,
211    hash_hex: &str,
212    path: &Path,
213    model_dir: &Path,
214) -> ComponentModelVerification {
215    let rel = path
216        .strip_prefix(model_dir)
217        .unwrap_or(path)
218        .display()
219        .to_string();
220    let make = |result| ComponentModelVerification {
221        name: component.name.clone(),
222        version: component.version.clone(),
223        result,
224        hash: Some(hash_hex.to_string()),
225        file: Some(rel.clone()),
226    };
227
228    match verify_file_hash(path, hash_hex) {
229        Ok(r) if r.verified => make(ModelVerifyResult::Verified),
230        Ok(_) => make(ModelVerifyResult::Mismatch),
231        // An I/O error on a located file is treated as a mismatch: the file is
232        // present (it was indexed) but unreadable, which is a verification
233        // failure, not a clean "missing".
234        Err(_) => make(ModelVerifyResult::Mismatch),
235    }
236}
237
238/// Candidate weight filenames for a component, by name.
239///
240/// Real weight files are not named after the component in the HF layout (they
241/// are sha256-named blobs, handled separately), but locally-laid-out model
242/// directories often use `model.*` or `<name>.*`. These are basename matches,
243/// so the directory walk handles any nesting.
244fn filename_candidates(component: &Component) -> Vec<String> {
245    let exts = [
246        "safetensors",
247        "bin",
248        "pt",
249        "pth",
250        "onnx",
251        "gguf",
252        "ggml",
253        "h5",
254        "pb",
255        "tflite",
256    ];
257    let stems = ["model", "pytorch_model", component.name.as_str()];
258
259    let mut out = Vec::new();
260    for stem in stems {
261        if stem.is_empty() {
262            continue;
263        }
264        for ext in exts {
265            out.push(format!("{stem}.{ext}"));
266        }
267    }
268    out
269}
270
/// Basename → path lookup table covering every file below a directory.
///
/// The HuggingFace cache keeps the weight bytes in `blobs/<sha256>` and exposes
/// them through human-named symlinks in `snapshots/<rev>/`. Keying on the
/// basename means both the digest-named blob and a plain `model.safetensors`
/// can be found no matter how deeply they are nested. Keys are lower-cased, so
/// lookups are case-insensitive. If a basename occurs more than once, the file
/// the walk encounters first is kept; a wrong pick is still caught later by
/// hash verification.
///
/// Every stored path is canonical and lies *inside* the model-dir root. Entries
/// whose resolved location leaves the root — through a symlink or a `..`
/// segment — are dropped, so `verify --model-dir` cannot be steered into
/// reading files outside the tree it was given. The `snapshots → blobs`
/// symlinks of the HuggingFace cache stay within the root and keep working.
struct FileIndex {
    by_name: HashMap<String, PathBuf>,
}
287
288impl FileIndex {
289    /// Build the index from a *canonicalized* `root`. Every candidate path is
290    /// itself canonicalized (which follows symlinks) and only retained when the
291    /// resolved path is still within `root`; this is the symlink-escape bound.
292    fn build(root: &Path) -> Self {
293        let mut by_name = HashMap::new();
294        let mut stack = vec![root.to_path_buf()];
295        // Directories are canonical here, so a `visited` set makes the walk
296        // robust against symlinked-directory cycles within the tree.
297        let mut visited: std::collections::HashSet<PathBuf> = std::collections::HashSet::new();
298
299        while let Some(dir) = stack.pop() {
300            if !visited.insert(dir.clone()) {
301                continue;
302            }
303            let Ok(entries) = std::fs::read_dir(&dir) else {
304                continue;
305            };
306            for entry in entries.flatten() {
307                let path = entry.path();
308                // Resolve the entry fully (follows symlinks, normalizes `..`).
309                // A path that fails to resolve (dangling symlink) is skipped.
310                let Ok(resolved) = std::fs::canonicalize(&path) else {
311                    continue;
312                };
313                // Reject anything that escapes the model-dir root. Without this a
314                // crafted `model.safetensors -> /etc/passwd` (or `../secret`)
315                // symlink would let an attacker have the verifier read an
316                // arbitrary file outside the directory under audit.
317                if !resolved.starts_with(root) {
318                    continue;
319                }
320                let meta = match std::fs::metadata(&resolved) {
321                    Ok(m) => m,
322                    Err(_) => continue,
323                };
324                if meta.is_dir() {
325                    stack.push(resolved);
326                } else if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
327                    // Key on the on-disk basename (e.g. the human-readable
328                    // snapshot name), but store the bounded, resolved path so the
329                    // subsequent hash read targets the in-tree bytes.
330                    by_name
331                        .entry(name.to_lowercase())
332                        .or_insert_with(|| resolved.clone());
333                }
334            }
335        }
336
337        Self { by_name }
338    }
339
340    /// Look up a file by basename (case-insensitive).
341    fn by_basename(&self, name: &str) -> Option<&Path> {
342        self.by_name.get(&name.to_lowercase()).map(PathBuf::as_path)
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DocumentMetadata, Hash};
    use sha2::{Digest, Sha256};
    use std::fs;

    /// Lower-case hex SHA-256 digest of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        use std::fmt::Write as _;

        let mut hasher = Sha256::new();
        hasher.update(bytes);
        let digest = hasher.finalize();

        let mut hex = String::with_capacity(digest.len() * 2);
        for byte in digest.iter() {
            write!(hex, "{byte:02x}").expect("writing to a String cannot fail");
        }
        hex
    }

    /// An ML-model component called `name`, version 1.0.0, carrying a single
    /// SHA-256 hash.
    fn model_component(name: &str, hash_hex: &str) -> Component {
        let mut component = Component::new(name.to_string(), format!("{name}-ref"))
            .with_version("1.0.0".to_string());
        component.component_type = ComponentType::MachineLearningModel;
        component
            .hashes
            .push(Hash::new(HashAlgorithm::Sha256, hash_hex.to_string()));
        component
    }

    /// An SBOM with default metadata that contains exactly `component`.
    fn sbom_with(component: Component) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(component);
        sbom
    }

    #[test]
    fn verifies_against_hf_blob_named_by_sha256() {
        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"fake model weights";
        let digest = sha256_hex(payload);

        // Lay the file out the way the HuggingFace cache does: blobs/<sha256>.
        let blob_dir = model_dir.path().join("blobs");
        fs::create_dir_all(&blob_dir).unwrap();
        fs::write(blob_dir.join(&digest), payload).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 1);
        assert_eq!(report.verified_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Verified);
        assert!(!report.has_failures());
    }

    #[test]
    fn verifies_against_direct_filename() {
        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"safetensors bytes";
        let digest = sha256_hex(payload);
        fs::write(model_dir.path().join("model.safetensors"), payload).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.verified_count, 1);
        assert_eq!(
            report.components[0].file.as_deref(),
            Some("model.safetensors")
        );
    }

    #[test]
    fn detects_tampering_as_mismatch() {
        let model_dir = tempfile::tempdir().unwrap();
        // What is on disk differs from what the SBOM hash describes → tampering.
        fs::write(model_dir.path().join("model.safetensors"), b"tampered bytes").unwrap();
        let claimed_digest = sha256_hex(b"original bytes");

        let sbom = sbom_with(model_component("bert", &claimed_digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.mismatch_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Mismatch);
        assert!(report.has_failures());
    }

    #[test]
    fn reports_missing_when_no_file_found() {
        let model_dir = tempfile::tempdir().unwrap();
        let digest = sha256_hex(b"weights that are not on disk");

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.missing_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Missing);
    }

    #[test]
    fn reports_no_hash_when_only_weak_hash_present() {
        let model_dir = tempfile::tempdir().unwrap();
        let mut component = Component::new("bert".to_string(), "bert-ref".to_string());
        component.component_type = ComponentType::MachineLearningModel;
        component
            .hashes
            .push(Hash::new(HashAlgorithm::Md5, "deadbeef".to_string()));

        let sbom = sbom_with(component);
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.no_hash_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::NoHash);
    }

    #[cfg(unix)]
    #[test]
    fn does_not_follow_symlink_escaping_model_dir() {
        use std::os::unix::fs::symlink;

        // The genuine weight bytes are stored OUTSIDE the model directory.
        let outside_dir = tempfile::tempdir().unwrap();
        let payload = b"weights that live outside the model dir";
        let digest = sha256_hex(payload);
        let outside_file = outside_dir.path().join("model.safetensors");
        fs::write(&outside_file, payload).unwrap();

        // The model dir only holds a symlink, with a plausible weight name,
        // pointing at that out-of-tree file. Following it would yield VERIFIED
        // and thereby leak the result of reading an arbitrary path.
        let model_dir = tempfile::tempdir().unwrap();
        symlink(&outside_file, model_dir.path().join("model.safetensors")).unwrap();

        let sbom = sbom_with(model_component("escape", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 1);
        assert_eq!(
            report.verified_count, 0,
            "a symlink escaping the model dir must not be followed/verified"
        );
        assert_eq!(
            report.components[0].result,
            ModelVerifyResult::Missing,
            "out-of-tree symlink target is treated as no in-tree file found"
        );
    }

    #[cfg(unix)]
    #[test]
    fn follows_intra_tree_symlink_like_hf_cache() {
        use std::os::unix::fs::symlink;

        // HuggingFace layout: blobs/<sha256> plus a snapshots/ symlink that
        // remains WITHIN the model dir. Verification must still succeed; the
        // escape guard only rejects targets that leave the root.
        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"in-tree hf blob bytes";
        let digest = sha256_hex(payload);

        let blob_dir = model_dir.path().join("blobs");
        let snapshot_dir = model_dir.path().join("snapshots").join("main");
        fs::create_dir_all(&blob_dir).unwrap();
        fs::create_dir_all(&snapshot_dir).unwrap();
        let blob_path = blob_dir.join(&digest);
        fs::write(&blob_path, payload).unwrap();
        symlink(&blob_path, snapshot_dir.join("model.safetensors")).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(
            report.verified_count, 1,
            "intra-tree HF snapshot→blob symlink must still verify"
        );
    }

    #[test]
    fn ignores_non_model_components() {
        let model_dir = tempfile::tempdir().unwrap();
        let mut component = Component::new("lib".to_string(), "lib-ref".to_string());
        component.component_type = ComponentType::Library;
        component
            .hashes
            .push(Hash::new(HashAlgorithm::Sha256, "a".repeat(64)));

        let sbom = sbom_with(component);
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 0, "library components are not models");
    }
}
530}