1use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18
19use serde::{Deserialize, Serialize};
20
21use crate::model::{Component, ComponentType, HashAlgorithm, NormalizedSbom};
22use crate::verification::verify_file_hash;
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub enum ModelVerifyResult {
27 Verified,
29 Mismatch,
31 Missing,
33 NoHash,
35}
36
37impl ModelVerifyResult {
38 #[must_use]
40 pub const fn label(&self) -> &'static str {
41 match self {
42 Self::Verified => "VERIFIED",
43 Self::Mismatch => "MISMATCH",
44 Self::Missing => "MISSING",
45 Self::NoHash => "NO-HASH",
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ComponentModelVerification {
53 pub name: String,
55 pub version: Option<String>,
57 pub result: ModelVerifyResult,
59 pub hash: Option<String>,
61 pub file: Option<String>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ModelVerifyReport {
69 pub model_dir: String,
71 pub total_models: usize,
73 pub verified_count: usize,
75 pub mismatch_count: usize,
77 pub missing_count: usize,
79 pub no_hash_count: usize,
81 pub components: Vec<ComponentModelVerification>,
83}
84
85impl ModelVerifyReport {
86 #[must_use]
88 pub const fn has_failures(&self) -> bool {
89 self.mismatch_count > 0 || self.missing_count > 0
90 }
91}
92
93const fn is_verifiable(alg: &HashAlgorithm) -> bool {
99 matches!(alg, HashAlgorithm::Sha256 | HashAlgorithm::Sha512)
100}
101
102#[must_use]
105pub fn verify_model_dir(sbom: &NormalizedSbom, model_dir: &Path) -> ModelVerifyReport {
106 let root = std::fs::canonicalize(model_dir).unwrap_or_else(|_| model_dir.to_path_buf());
111
112 let index = FileIndex::build(&root);
116
117 let mut report = ModelVerifyReport {
118 model_dir: model_dir.display().to_string(),
119 total_models: 0,
120 verified_count: 0,
121 mismatch_count: 0,
122 missing_count: 0,
123 no_hash_count: 0,
124 components: Vec::new(),
125 };
126
127 for component in sbom.components.values() {
128 if !is_model_like(component) {
129 continue;
130 }
131 report.total_models += 1;
132
133 let record = verify_component(component, &root, &index);
134 match record.result {
135 ModelVerifyResult::Verified => report.verified_count += 1,
136 ModelVerifyResult::Mismatch => report.mismatch_count += 1,
137 ModelVerifyResult::Missing => report.missing_count += 1,
138 ModelVerifyResult::NoHash => report.no_hash_count += 1,
139 }
140 report.components.push(record);
141 }
142
143 report
144}
145
146fn is_model_like(component: &Component) -> bool {
148 matches!(
149 component.component_type,
150 ComponentType::MachineLearningModel | ComponentType::Data
151 )
152}
153
154fn verify_component(
156 component: &Component,
157 model_dir: &Path,
158 index: &FileIndex,
159) -> ComponentModelVerification {
160 let make = |result, hash: Option<String>, file: Option<String>| ComponentModelVerification {
161 name: component.name.clone(),
162 version: component.version.clone(),
163 result,
164 hash,
165 file,
166 };
167
168 let verifiable: Vec<_> = component
170 .hashes
171 .iter()
172 .filter(|h| is_verifiable(&h.algorithm))
173 .collect();
174
175 if verifiable.is_empty() {
176 return make(ModelVerifyResult::NoHash, None, None);
177 }
178
179 let name_candidates = filename_candidates(component);
184
185 let mut last_missing_hash: Option<String> = None;
186
187 for hash in verifiable {
188 let hash_hex = hash.value.to_lowercase();
189 last_missing_hash = Some(hash_hex.clone());
190
191 if let Some(path) = index.by_basename(&hash_hex) {
194 return verify_against(component, &hash_hex, path, model_dir);
195 }
196
197 for candidate in &name_candidates {
199 if let Some(path) = index.by_basename(candidate) {
200 return verify_against(component, &hash_hex, path, model_dir);
201 }
202 }
203 }
204
205 make(ModelVerifyResult::Missing, last_missing_hash, None)
206}
207
208fn verify_against(
210 component: &Component,
211 hash_hex: &str,
212 path: &Path,
213 model_dir: &Path,
214) -> ComponentModelVerification {
215 let rel = path
216 .strip_prefix(model_dir)
217 .unwrap_or(path)
218 .display()
219 .to_string();
220 let make = |result| ComponentModelVerification {
221 name: component.name.clone(),
222 version: component.version.clone(),
223 result,
224 hash: Some(hash_hex.to_string()),
225 file: Some(rel.clone()),
226 };
227
228 match verify_file_hash(path, hash_hex) {
229 Ok(r) if r.verified => make(ModelVerifyResult::Verified),
230 Ok(_) => make(ModelVerifyResult::Mismatch),
231 Err(_) => make(ModelVerifyResult::Mismatch),
235 }
236}
237
238fn filename_candidates(component: &Component) -> Vec<String> {
245 let exts = [
246 "safetensors",
247 "bin",
248 "pt",
249 "pth",
250 "onnx",
251 "gguf",
252 "ggml",
253 "h5",
254 "pb",
255 "tflite",
256 ];
257 let stems = ["model", "pytorch_model", component.name.as_str()];
258
259 let mut out = Vec::new();
260 for stem in stems {
261 if stem.is_empty() {
262 continue;
263 }
264 for ext in exts {
265 out.push(format!("{stem}.{ext}"));
266 }
267 }
268 out
269}
270
/// Basename → canonical-path index of the regular files under a model directory.
///
/// It is built once per verification run, so each component lookup is a
/// hash-map hit rather than a directory walk. Keys are lower-cased basenames,
/// which makes lookups case-insensitive.
struct FileIndex {
    /// Lower-cased basename of the directory entry (for a symlink, the link's own
    /// name, not its target's) mapped to the canonical path it resolves to.
    by_name: HashMap<String, PathBuf>,
}

impl FileIndex {
    /// Walk `root` breadth-first and index every regular file that resolves inside it.
    ///
    /// `root` should already be canonical; containment is checked with a plain
    /// `starts_with` on canonical paths.
    ///
    /// - Every entry is canonicalised and skipped unless it lies under `root`.
    ///   A symlink escaping the tree is never followed, while intra-tree links
    ///   such as the HF `snapshots/ → blobs/` layout still resolve.
    /// - Directories are tracked by canonical path, so symlink cycles terminate.
    /// - Only regular files are indexed. FIFOs, sockets and device nodes are
    ///   ignored because hashing them could block or read without bound.
    /// - `read_dir` order is platform-dependent, so entries are sorted and
    ///   directories are visited breadth-first. When several files share a
    ///   basename, the one closest to `root` wins; ties go to sorted order.
    ///
    /// Unreadable directories and entries, and non-UTF-8 names, are skipped: the
    /// index is best-effort.
    fn build(root: &Path) -> Self {
        let mut by_name: HashMap<String, PathBuf> = HashMap::new();
        let mut queue = std::collections::VecDeque::from([root.to_path_buf()]);
        let mut visited: std::collections::HashSet<PathBuf> = std::collections::HashSet::new();

        while let Some(dir) = queue.pop_front() {
            if !visited.insert(dir.clone()) {
                continue;
            }
            let Ok(entries) = std::fs::read_dir(&dir) else {
                continue;
            };
            let mut paths: Vec<PathBuf> = entries.flatten().map(|entry| entry.path()).collect();
            paths.sort();

            for path in paths {
                // Resolve symlinks so containment is judged on the real location.
                let Ok(resolved) = std::fs::canonicalize(&path) else {
                    continue;
                };
                if !resolved.starts_with(root) {
                    continue;
                }
                let Ok(meta) = std::fs::metadata(&resolved) else {
                    continue;
                };
                if meta.is_dir() {
                    queue.push_back(resolved);
                } else if meta.is_file() {
                    let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
                        continue;
                    };
                    by_name.entry(name.to_lowercase()).or_insert(resolved);
                }
            }
        }

        Self { by_name }
    }

    /// Canonical path of the indexed file named `name` (case-insensitive), if any.
    fn by_basename(&self, name: &str) -> Option<&Path> {
        self.by_name.get(&name.to_lowercase()).map(PathBuf::as_path)
    }
}
345
/// Unit tests for `verify_model_dir`. Each test builds a throwaway directory
/// with `tempfile::tempdir` and an in-memory SBOM, then checks the report.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DocumentMetadata, Hash};
    use sha2::{Digest, Sha256};
    use std::fs;

    /// Lower-case hex SHA-256 of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        let mut h = Sha256::new();
        h.update(bytes);
        h.finalize().iter().map(|b| format!("{b:02x}")).collect()
    }

    /// A version-1.0.0 `MachineLearningModel` component named `name` that
    /// carries a single SHA-256 digest `hash_hex`.
    fn model_component(name: &str, hash_hex: &str) -> Component {
        let mut c = Component::new(name.to_string(), format!("{name}-ref"))
            .with_version("1.0.0".to_string());
        c.component_type = ComponentType::MachineLearningModel;
        c.hashes
            .push(Hash::new(HashAlgorithm::Sha256, hash_hex.to_string()));
        c
    }

    /// A file named after its own digest (HF `blobs/<sha256>` layout) verifies.
    #[test]
    fn verifies_against_hf_blob_named_by_sha256() {
        let dir = tempfile::tempdir().unwrap();
        let weights = b"fake model weights";
        let hex = sha256_hex(weights);

        let blobs = dir.path().join("blobs");
        fs::create_dir_all(&blobs).unwrap();
        fs::write(blobs.join(&hex), weights).unwrap();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("bert", &hex));

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.total_models, 1);
        assert_eq!(report.verified_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Verified);
        assert!(!report.has_failures());
    }

    /// With no digest-named file, a conventional weight filename is used, and
    /// the reported path is relative to the model directory.
    #[test]
    fn verifies_against_direct_filename() {
        let dir = tempfile::tempdir().unwrap();
        let weights = b"safetensors bytes";
        let hex = sha256_hex(weights);
        fs::write(dir.path().join("model.safetensors"), weights).unwrap();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("bert", &hex));

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.verified_count, 1);
        assert_eq!(
            report.components[0].file.as_deref(),
            Some("model.safetensors")
        );
    }

    /// A located file whose contents differ from the SBOM digest is a Mismatch
    /// and counts as a failure.
    #[test]
    fn detects_tampering_as_mismatch() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("model.safetensors"), b"tampered bytes").unwrap();
        let claimed = sha256_hex(b"original bytes");

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("bert", &claimed));

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.mismatch_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Mismatch);
        assert!(report.has_failures());
    }

    /// An empty model directory yields Missing for a component that has a digest.
    #[test]
    fn reports_missing_when_no_file_found() {
        let dir = tempfile::tempdir().unwrap();
        let hex = sha256_hex(b"weights that are not on disk");

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("bert", &hex));

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.missing_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Missing);
    }

    /// An MD5-only component is reported as NoHash rather than checked.
    #[test]
    fn reports_no_hash_when_only_weak_hash_present() {
        let dir = tempfile::tempdir().unwrap();
        let mut c = Component::new("bert".to_string(), "bert-ref".to_string());
        c.component_type = ComponentType::MachineLearningModel;
        c.hashes
            .push(Hash::new(HashAlgorithm::Md5, "deadbeef".to_string()));

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(c);

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.no_hash_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::NoHash);
    }

    /// A symlink inside the model dir that points outside it must not be
    /// followed, even when the target's digest would match.
    #[cfg(unix)]
    #[test]
    fn does_not_follow_symlink_escaping_model_dir() {
        use std::os::unix::fs::symlink;

        let outside = tempfile::tempdir().unwrap();
        let weights = b"weights that live outside the model dir";
        let hex = sha256_hex(weights);
        let secret = outside.path().join("model.safetensors");
        fs::write(&secret, weights).unwrap();

        let model_dir = tempfile::tempdir().unwrap();
        symlink(&secret, model_dir.path().join("model.safetensors")).unwrap();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("escape", &hex));

        let report = verify_model_dir(&sbom, model_dir.path());
        assert_eq!(report.total_models, 1);
        assert_eq!(
            report.verified_count, 0,
            "a symlink escaping the model dir must not be followed/verified"
        );
        assert_eq!(
            report.components[0].result,
            ModelVerifyResult::Missing,
            "out-of-tree symlink target is treated as no in-tree file found"
        );
    }

    /// A symlink that stays inside the tree (HF `snapshots/main/model.safetensors`
    /// → `blobs/<sha256>`) is still followed and verifies.
    #[cfg(unix)]
    #[test]
    fn follows_intra_tree_symlink_like_hf_cache() {
        use std::os::unix::fs::symlink;

        let dir = tempfile::tempdir().unwrap();
        let weights = b"in-tree hf blob bytes";
        let hex = sha256_hex(weights);

        let blobs = dir.path().join("blobs");
        let snapshots = dir.path().join("snapshots").join("main");
        fs::create_dir_all(&blobs).unwrap();
        fs::create_dir_all(&snapshots).unwrap();
        let blob = blobs.join(&hex);
        fs::write(&blob, weights).unwrap();
        symlink(&blob, snapshots.join("model.safetensors")).unwrap();

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(model_component("bert", &hex));

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(
            report.verified_count, 1,
            "intra-tree HF snapshot→blob symlink must still verify"
        );
    }

    /// Non-model components (here a `Library`) are excluded from the report
    /// entirely, even when they carry a SHA-256 digest.
    #[test]
    fn ignores_non_model_components() {
        let dir = tempfile::tempdir().unwrap();
        let mut c = Component::new("lib".to_string(), "lib-ref".to_string());
        c.component_type = ComponentType::Library;
        c.hashes
            .push(Hash::new(HashAlgorithm::Sha256, "a".repeat(64)));

        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(c);

        let report = verify_model_dir(&sbom, dir.path());
        assert_eq!(report.total_models, 0, "library components are not models");
    }
}
530}