// ruvector_profiler/config_hash.rs

1#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
2pub struct BenchConfig {
3    pub model_commit: String,
4    pub weights_hash: String,
5    pub lambda: f32,
6    pub tau: usize,
7    pub eps: f32,
8    pub compiler_flags: String,
9}
10
11/// SHA-256 hex digest of the JSON-serialised config.
12pub fn config_hash(config: &BenchConfig) -> String {
13    let json = serde_json::to_string(config).expect("BenchConfig serializable");
14    sha256(json.as_bytes())
15        .iter()
16        .map(|b| format!("{b:02x}"))
17        .collect()
18}
19
/// Computes the SHA-256 digest of `data` as specified in FIPS 180-4.
///
/// The input is copied into a padded buffer before hashing, so memory use is
/// proportional to `data.len()`. The 32-byte result is the eight state words
/// written out big-endian.
fn sha256(data: &[u8]) -> [u8; 32] {
    // Round constants: first 32 bits of the fractional parts of the cube
    // roots of the first 64 primes (FIPS 180-4 §4.2.2).
    #[rustfmt::skip]
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5,
        0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174,
        0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da,
        0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967,
        0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85,
        0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070,
        0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3,
        0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2,
    ];
    // Initial hash value H(0) (FIPS 180-4 §5.3.3).
    const INITIAL_STATE: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a single 0x80 byte, zeros until the length is 56 mod 64, then
    // the original message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let mut padded = data.to_vec();
    padded.push(0x80);
    // 120 = 56 + 64 keeps the subtraction non-negative for any remainder.
    let zero_fill = (120 - padded.len() % 64) % 64;
    padded.resize(padded.len() + zero_fill, 0);
    padded.extend_from_slice(&bit_len.to_be_bytes());

    let mut state = INITIAL_STATE;
    for block in padded.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, then 48 derived.
        let mut schedule = [0u32; 64];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..64 {
            let (older, recent) = (schedule[t - 15], schedule[t - 2]);
            let sigma0 = older.rotate_right(7) ^ older.rotate_right(18) ^ (older >> 3);
            let sigma1 = recent.rotate_right(17) ^ recent.rotate_right(19) ^ (recent >> 10);
            schedule[t] = sigma1
                .wrapping_add(schedule[t - 7])
                .wrapping_add(sigma0)
                .wrapping_add(schedule[t - 16]);
        }

        // 64 compression rounds over a copy of the state, [a, b, c, d, e, f, g, h].
        let mut working = state;
        for (&k, &w) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let [a, b, c, d, e, f, g, h] = working;
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let temp1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(k)
                .wrapping_add(w);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_sigma0.wrapping_add(majority);
            working = [temp1.wrapping_add(temp2), a, b, c, d.wrapping_add(temp1), e, f, g];
        }

        for (word, delta) in state.iter_mut().zip(working.iter()) {
            *word = word.wrapping_add(*delta);
        }
    }

    let mut digest = [0u8; 32];
    for (dst, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        dst.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase hex encoding of `sha256(data)`, for comparison with published vectors.
    fn hex(data: &[u8]) -> String {
        sha256(data).iter().map(|b| format!("{b:02x}")).collect()
    }

    /// Baseline config for the `config_hash` tests; each test tweaks single fields.
    fn base() -> BenchConfig {
        BenchConfig {
            model_commit: "a".into(),
            weights_hash: "b".into(),
            lambda: 0.1,
            tau: 64,
            eps: 1e-6,
            compiler_flags: "-O3".into(),
        }
    }

    // Known-answer tests from the FIPS 180-2 / NIST example vectors.
    #[test]
    fn sha_empty() {
        assert_eq!(
            hex(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn sha_abc() {
        assert_eq!(
            hex(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    /// 56-byte message: after the 0x80 marker the 8-byte length no longer fits
    /// in the first block, so padding must spill into a second block.
    #[test]
    fn sha_two_blocks() {
        assert_eq!(
            hex(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"),
            "248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1"
        );
    }

    #[test]
    fn deterministic() {
        let c = base();
        let (h1, h2) = (config_hash(&c), config_hash(&c.clone()));
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 64);
        // Output is lowercase hexadecimal only.
        assert!(h1
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
    }

    #[test]
    fn varies() {
        let mk = |s: &str| BenchConfig {
            model_commit: s.into(),
            weights_hash: "x".into(),
            lambda: 0.1,
            tau: 64,
            eps: 1e-6,
            compiler_flags: "".into(),
        };
        assert_ne!(config_hash(&mk("a")), config_hash(&mk("b")));
    }

    /// Every field participates in the hash: perturbing any one changes it.
    #[test]
    fn each_field_affects_hash() {
        let reference = config_hash(&base());
        let variants = [
            BenchConfig { model_commit: "z".into(), ..base() },
            BenchConfig { weights_hash: "z".into(), ..base() },
            BenchConfig { lambda: 0.2, ..base() },
            BenchConfig { tau: 65, ..base() },
            BenchConfig { eps: 1e-5, ..base() },
            BenchConfig { compiler_flags: "-O2".into(), ..base() },
        ];
        for v in &variants {
            assert_ne!(config_hash(v), reference, "{v:?}");
        }
    }
}