1#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
2pub struct BenchConfig {
3 pub model_commit: String,
4 pub weights_hash: String,
5 pub lambda: f32,
6 pub tau: usize,
7 pub eps: f32,
8 pub compiler_flags: String,
9}
10
11pub fn config_hash(config: &BenchConfig) -> String {
13 let json = serde_json::to_string(config).expect("BenchConfig serializable");
14 sha256(json.as_bytes()).iter().map(|b| format!("{b:02x}")).collect()
15}
16
/// Computes the SHA-256 digest of `data` and returns the 32 raw digest bytes.
///
/// The message is padded with a single `0x80` byte, zero bytes up to 56 modulo
/// 64, and the message length in bits as a big-endian `u64`; the padded message
/// is then compressed 64 bytes at a time.
fn sha256(data: &[u8]) -> [u8; 32] {
    // Per-round additive constants.
    #[rustfmt::skip]
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
        0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
        0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
        0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
        0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
        0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
        0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
    ];
    // Starting value of the eight-word chaining state.
    #[rustfmt::skip]
    const INITIAL_STATE: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
    ];

    // Build the padded message. After the 0x80 marker the length must be
    // 56 (mod 64) so that the 8-byte bit count completes the final block.
    let message_bits = (data.len() as u64) * 8;
    let zero_fill = (119 - data.len() % 64) % 64;
    let unlengthed_len = data.len() + 1 + zero_fill;
    let mut padded = Vec::with_capacity(unlengthed_len + 8);
    padded.extend_from_slice(data);
    padded.push(0x80);
    padded.resize(unlengthed_len, 0);
    padded.extend_from_slice(&message_bits.to_be_bytes());

    let mut state = INITIAL_STATE;
    for block in padded.chunks_exact(64) {
        // Message schedule: the first 16 words come straight from the block
        // (big-endian); the remaining 48 are derived from earlier words.
        let mut schedule = [0u32; 64];
        for (word, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..64 {
            let x = schedule[t - 15];
            let y = schedule[t - 2];
            let sigma0 = x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3);
            let sigma1 = y.rotate_right(17) ^ y.rotate_right(19) ^ (y >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(sigma1);
        }

        // 64 compression rounds over a working copy of the state.
        let mut working = state;
        for (&k, &w) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let [a, b, c, d, e, f, g, last] = working;
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let temp1 = last
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(k)
                .wrapping_add(w);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_sigma0.wrapping_add(majority);
            // Shift every word down one slot; the new `a` and `e` absorb the temporaries.
            working = [temp1.wrapping_add(temp2), a, b, c, d.wrapping_add(temp1), e, f, g];
        }

        // Fold the block's result back into the chaining state.
        for (slot, value) in state.iter_mut().zip(working.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    // Serialize the final state big-endian, four bytes per word.
    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex-encodes the SHA-256 digest of `data` so it can be compared with published vectors.
    fn hex(data: &[u8]) -> String {
        sha256(data).iter().map(|b| format!("{b:02x}")).collect()
    }

    /// Baseline configuration; tests override single fields as needed.
    fn sample() -> BenchConfig {
        BenchConfig {
            model_commit: "a".into(),
            weights_hash: "b".into(),
            lambda: 0.1,
            tau: 64,
            eps: 1e-6,
            compiler_flags: "-O3".into(),
        }
    }

    #[test]
    fn sha_empty() {
        assert_eq!(
            hex(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn sha_abc() {
        assert_eq!(
            hex(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // A 56-byte message leaves no room for the length field in its own block,
    // so padding spills into a second block. The single-block vectors above
    // never reach that path.
    #[test]
    fn sha_two_blocks() {
        assert_eq!(
            hex(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"),
            "248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1"
        );
    }

    #[test]
    fn deterministic() {
        let c = sample();
        let (h1, h2) = (config_hash(&c), config_hash(&c));
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 64);
    }

    #[test]
    fn varies() {
        let mk = |s: &str| BenchConfig {
            model_commit: s.into(),
            weights_hash: "x".into(),
            lambda: 0.1,
            tau: 64,
            eps: 1e-6,
            compiler_flags: "".into(),
        };
        assert_ne!(config_hash(&mk("a")), config_hash(&mk("b")));
    }

    // Every field must feed the fingerprint, not only `model_commit`.
    #[test]
    fn varies_with_every_field() {
        let base = sample();
        let reference = config_hash(&base);
        let variants = [
            BenchConfig { model_commit: "z".into(), ..base.clone() },
            BenchConfig { weights_hash: "z".into(), ..base.clone() },
            BenchConfig { lambda: 0.2, ..base.clone() },
            BenchConfig { tau: 65, ..base.clone() },
            BenchConfig { eps: 1e-5, ..base.clone() },
            BenchConfig { compiler_flags: "-O2".into(), ..base.clone() },
        ];
        for variant in variants.iter() {
            assert_ne!(config_hash(variant), reference, "hash ignored a change: {variant:?}");
        }
    }
}