Skip to main content

viewport_lib/renderer/
shader_hashes.rs

//! Shader hash pinning for runtime integrity validation.
//!
//! Provides FNV-1a 64-bit hashing for all 17 WGSL shaders in the viewport library.
//! Use `current_shader_hashes()` to snapshot hashes at build time, then
//! `validate_shader_hashes()` at startup to detect accidental shader changes.
6
7// ---------------------------------------------------------------------------
8// FNV-1a 64-bit hash
9// ---------------------------------------------------------------------------
10
/// Initial accumulator value (offset basis) for 64-bit FNV-1a.
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each byte is mixed in (64-bit FNV prime).
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Hash `data` with 64-bit FNV-1a.
///
/// Starting from [`FNV_OFFSET`], each byte is XOR-ed into the accumulator,
/// which is then multiplied by [`FNV_PRIME`] with wrapping arithmetic. The
/// result depends only on the input bytes, so equal inputs always produce
/// equal hashes; an empty slice yields the offset basis itself.
pub fn fnv1a_hash(data: &[u8]) -> u64 {
    data.iter().fold(FNV_OFFSET, |acc, &octet| {
        // XOR first, multiply second: this ordering is what makes it the "1a" variant.
        (acc ^ u64::from(octet)).wrapping_mul(FNV_PRIME)
    })
}
23
24// ---------------------------------------------------------------------------
25// Shader catalog
26// ---------------------------------------------------------------------------
27
/// One entry in the shader catalog: a shader's file name paired with its source text.
///
/// Both fields are `&'static str`, so the entry is trivially copyable and can be
/// compared or debug-printed without allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShaderEntry {
    /// Human-readable shader name (filename without path).
    pub name: &'static str,
    /// Full WGSL source as embedded at compile time.
    pub source: &'static str,
}
35
36/// All 17 shaders embedded via `include_str!`.
37///
38/// Order matches the filesystem order in `src/shaders/`.
39pub const SHADERS: &[ShaderEntry] = &[
40    ShaderEntry {
41        name: "axes_overlay.wgsl",
42        source: include_str!("../shaders/axes_overlay.wgsl"),
43    },
44    ShaderEntry {
45        name: "bloom_blur.wgsl",
46        source: include_str!("../shaders/bloom_blur.wgsl"),
47    },
48    ShaderEntry {
49        name: "bloom_threshold.wgsl",
50        source: include_str!("../shaders/bloom_threshold.wgsl"),
51    },
52    ShaderEntry {
53        name: "contact_shadow.wgsl",
54        source: include_str!("../shaders/contact_shadow.wgsl"),
55    },
56    ShaderEntry {
57        name: "fxaa.wgsl",
58        source: include_str!("../shaders/fxaa.wgsl"),
59    },
60    ShaderEntry {
61        name: "grid.wgsl",
62        source: include_str!("../shaders/grid.wgsl"),
63    },
64    ShaderEntry {
65        name: "gizmo.wgsl",
66        source: include_str!("../shaders/gizmo.wgsl"),
67    },
68    ShaderEntry {
69        name: "mesh.wgsl",
70        source: include_str!("../shaders/mesh.wgsl"),
71    },
72    ShaderEntry {
73        name: "mesh_instanced.wgsl",
74        source: include_str!("../shaders/mesh_instanced.wgsl"),
75    },
76    ShaderEntry {
77        name: "outline.wgsl",
78        source: include_str!("../shaders/outline.wgsl"),
79    },
80    ShaderEntry {
81        name: "outline_composite.wgsl",
82        source: include_str!("../shaders/outline_composite.wgsl"),
83    },
84    ShaderEntry {
85        name: "overlay.wgsl",
86        source: include_str!("../shaders/overlay.wgsl"),
87    },
88    ShaderEntry {
89        name: "shadow.wgsl",
90        source: include_str!("../shaders/shadow.wgsl"),
91    },
92    ShaderEntry {
93        name: "shadow_instanced.wgsl",
94        source: include_str!("../shaders/shadow_instanced.wgsl"),
95    },
96    ShaderEntry {
97        name: "ssao.wgsl",
98        source: include_str!("../shaders/ssao.wgsl"),
99    },
100    ShaderEntry {
101        name: "ssao_blur.wgsl",
102        source: include_str!("../shaders/ssao_blur.wgsl"),
103    },
104    ShaderEntry {
105        name: "tone_map.wgsl",
106        source: include_str!("../shaders/tone_map.wgsl"),
107    },
108];
109
110// ---------------------------------------------------------------------------
111// Public API
112// ---------------------------------------------------------------------------
113
/// Result of a shader hash validation run.
///
/// `Default` yields the "nothing checked" state: zero valid shaders and an
/// empty mismatch list.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ShaderValidation {
    /// Number of shaders whose hashes matched expected values.
    pub valid: usize,
    /// Names of shaders that failed validation: either the hash differed from
    /// the expected value, or the name was not found in the shader catalog.
    pub mismatched: Vec<String>,
}
121
122/// Return the current FNV-1a hash for every shader in the catalog.
123///
124/// Returns `(name, hash)` pairs in catalog order.
125/// Use this to snapshot the expected hashes for later validation.
126pub fn current_shader_hashes() -> Vec<(&'static str, u64)> {
127    SHADERS
128        .iter()
129        .map(|s| (s.name, fnv1a_hash(s.source.as_bytes())))
130        .collect()
131}
132
133/// Compare `expected` hashes against the current compiled-in shader sources.
134///
135/// Logs a `tracing::warn!` for each mismatch.
136/// Returns `ShaderValidation` with the count of matching shaders and names of
137/// mismatched ones.
138///
139/// Shaders not present in `expected` are skipped (not counted as mismatched).
140pub fn validate_shader_hashes(expected: &[(&str, u64)]) -> ShaderValidation {
141    let current: std::collections::HashMap<&str, u64> =
142        current_shader_hashes().into_iter().collect();
143
144    let mut valid = 0usize;
145    let mut mismatched = Vec::new();
146
147    for (name, exp_hash) in expected {
148        match current.get(name) {
149            Some(&cur_hash) if cur_hash == *exp_hash => {
150                valid += 1;
151            }
152            Some(&cur_hash) => {
153                tracing::warn!(
154                    shader = %name,
155                    expected = %exp_hash,
156                    actual = %cur_hash,
157                    "shader hash mismatch — shader may have been modified unexpectedly"
158                );
159                mismatched.push(name.to_string());
160            }
161            None => {
162                tracing::warn!(
163                    shader = %name,
164                    "shader not found in catalog during validation"
165                );
166                mismatched.push(name.to_string());
167            }
168        }
169    }
170
171    ShaderValidation { valid, mismatched }
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for the FNV-1a helper, the shader catalog, and hash validation.

    use super::*;

    #[test]
    fn test_fnv1a_hash_deterministic() {
        let h1 = fnv1a_hash(b"hello");
        let h2 = fnv1a_hash(b"hello");
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_fnv1a_hash_different_inputs_differ() {
        let h1 = fnv1a_hash(b"hello");
        let h2 = fnv1a_hash(b"world");
        assert_ne!(h1, h2);
    }

    /// Pins the implementation to published FNV-1a 64-bit reference values so a
    /// change to the constants or the XOR/multiply order cannot go unnoticed.
    #[test]
    fn test_fnv1a_hash_known_vectors() {
        // Empty input leaves the offset basis untouched.
        assert_eq!(fnv1a_hash(b""), 0xcbf29ce484222325);
        assert_eq!(fnv1a_hash(b"a"), 0xaf63dc4c8601ec8c);
        assert_eq!(fnv1a_hash(b"foobar"), 0x85944171f73967e8);
    }

    #[test]
    fn test_current_shader_hashes_returns_17_entries() {
        let hashes = current_shader_hashes();
        assert_eq!(
            hashes.len(),
            17,
            "expected 17 shaders, got {}",
            hashes.len()
        );
        // One hash per catalog entry, regardless of the literal count above.
        assert_eq!(hashes.len(), SHADERS.len());
    }

    #[test]
    fn test_current_shader_hashes_all_names_present() {
        let hashes = current_shader_hashes();
        let names: Vec<&str> = hashes.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"mesh.wgsl"));
        assert!(names.contains(&"gizmo.wgsl"));
        assert!(names.contains(&"shadow_instanced.wgsl"));
        assert!(names.contains(&"tone_map.wgsl"));
    }

    /// Validation looks shaders up by name, so a duplicated name in the catalog
    /// would silently shadow another entry.
    #[test]
    fn test_shader_catalog_names_unique() {
        let unique: std::collections::HashSet<&str> = SHADERS.iter().map(|s| s.name).collect();
        assert_eq!(unique.len(), SHADERS.len());
    }

    #[test]
    fn test_validate_shader_hashes_all_correct_passes() {
        let hashes = current_shader_hashes();
        let result = validate_shader_hashes(&hashes);
        assert_eq!(result.valid, 17);
        assert!(result.mismatched.is_empty());
    }

    #[test]
    fn test_validate_shader_hashes_wrong_hash_reports_mismatch() {
        let wrong_hash: Vec<(&str, u64)> = vec![("mesh.wgsl", 0xdeadbeefcafe1234)];
        let result = validate_shader_hashes(&wrong_hash);
        assert_eq!(result.valid, 0);
        assert_eq!(result.mismatched.len(), 1);
        assert_eq!(result.mismatched[0], "mesh.wgsl");
    }

    /// A name that is not in the catalog is reported as mismatched, not skipped.
    #[test]
    fn test_validate_shader_hashes_unknown_shader_reports_mismatch() {
        let unknown: Vec<(&str, u64)> = vec![("does_not_exist.wgsl", 0)];
        let result = validate_shader_hashes(&unknown);
        assert_eq!(result.valid, 0);
        assert_eq!(result.mismatched, vec!["does_not_exist.wgsl".to_string()]);
    }

    #[test]
    fn test_validate_shader_hashes_empty_expected() {
        let result = validate_shader_hashes(&[]);
        assert_eq!(result.valid, 0);
        assert!(result.mismatched.is_empty());
    }

    #[test]
    fn test_validate_shader_hashes_partial_expected() {
        let hashes = current_shader_hashes();
        // Only validate the first 3 shaders
        let result = validate_shader_hashes(&hashes[..3]);
        assert_eq!(result.valid, 3);
        assert!(result.mismatched.is_empty());
    }
}