// starknet_crypto/rfc6979.rs

1use crypto_bigint::{ArrayEncoding, ByteArray, Integer, U256};
2use hmac::digest::Digest;
3use sha2::digest::{crypto_common::BlockSizeUser, FixedOutputReset, HashMarker};
4use starknet_types_core::felt::Felt;
5use zeroize::{Zeroize, Zeroizing};
6
7const EC_ORDER: U256 =
8    U256::from_be_hex("0800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f");
9
10/// Deterministically generate ephemeral scalar `k` based on RFC 6979.
11///
12/// ### Parameters
13///
14/// - `message_hash`: Message hash.
15/// - `private_key`: Private key.
16/// - `seed`: Extra seed for additional entropy.
17pub fn generate_k(message_hash: &Felt, private_key: &Felt, seed: Option<&Felt>) -> Felt {
18    // The message hash padding as implemented in `cairo-lang` is not needed here. The hash is
19    // padded in `cairo-lang` only to make sure the lowest 4 bits won't get truncated, but here it's
20    // never getting truncated anyways.
21    let message_hash = U256::from_be_slice(&message_hash.to_bytes_be()).to_be_byte_array();
22    let private_key = U256::from_be_slice(&private_key.to_bytes_be());
23
24    let seed_bytes = match seed {
25        Some(seed) => seed.to_bytes_be(),
26        None => [0u8; 32],
27    };
28
29    let mut first_non_zero_index = 32;
30    for (ind, element) in seed_bytes.iter().enumerate() {
31        if *element != 0u8 {
32            first_non_zero_index = ind;
33            break;
34        }
35    }
36
37    let k = generate_k_shifted::<sha2::Sha256, _>(
38        &private_key,
39        &EC_ORDER,
40        &message_hash,
41        &seed_bytes[first_non_zero_index..],
42    );
43
44    let mut buffer = [0u8; 32];
45    buffer[..].copy_from_slice(&k.to_be_byte_array()[..]);
46
47    Felt::from_bytes_be(&buffer)
48}
49
50// Modified from upstream `rfc6979::generate_k` with a hard-coded right bit shift. The more
51// idiomatic way of doing this seems to be to implement `U252` which handles bit truncation
52// interally.
53// TODO: change to use upstream `generate_k` directly.
54#[inline]
55fn generate_k_shifted<D, I>(x: &I, n: &I, h: &ByteArray<I>, data: &[u8]) -> Zeroizing<I>
56where
57    D: Default + Digest + BlockSizeUser + FixedOutputReset + HashMarker,
58    I: ArrayEncoding + Integer + Zeroize,
59{
60    let mut x = x.to_be_byte_array();
61    let mut hmac_drbg = rfc6979::HmacDrbg::<D>::new(&x, h, data);
62    x.zeroize();
63
64    loop {
65        let mut bytes = ByteArray::<I>::default();
66        hmac_drbg.fill_bytes(&mut bytes);
67        let k = I::from_be_byte_array(bytes) >> 4;
68
69        if (!k.is_zero() & k.ct_lt(n)).into() {
70            return Zeroizing::new(k);
71        }
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::field_element_from_be_hex;
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;

    use serde::Deserialize;

    /// One entry of a JSON test-vector file; all fields are big-endian hex strings.
    #[derive(Deserialize)]
    struct Rfc6979TestVector<'a> {
        msg_hash: &'a str,
        priv_key: &'a str,
        seed: &'a str,
        k: &'a str,
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_generate_k_padded() {
        // Test vectors generated from `cairo-lang`
        test_generate_k_from_json_str(include_str!("../test-data/rfc6979_padded.json"));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_generate_k_not_padded() {
        // Test vectors generated from `cairo-lang`
        test_generate_k_from_json_str(include_str!("../test-data/rfc6979_not_padded.json"));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_generate_k_zero_seed_matches_no_seed() {
        // An all-zero seed has every byte stripped as a leading zero, leaving the same empty
        // extra data that `None` produces.
        let msg_hash = Felt::from(0x1234_5678_u64);
        let priv_key = Felt::from(0x9abc_def0_u64);

        assert_eq!(
            generate_k(&msg_hash, &priv_key, Some(&Felt::ZERO)),
            generate_k(&msg_hash, &priv_key, None)
        );
    }

    /// Parses `json_str` as a list of test vectors and checks `generate_k` against each one.
    fn test_generate_k_from_json_str(json_str: &'static str) {
        let test_vectors: Vec<Rfc6979TestVector<'_>> = serde_json::from_str(json_str).unwrap();

        for (index, test_vector) in test_vectors.iter().enumerate() {
            let msg_hash = field_element_from_be_hex(test_vector.msg_hash);
            let priv_key = field_element_from_be_hex(test_vector.priv_key);
            let seed = field_element_from_be_hex(test_vector.seed);
            let expected_k = field_element_from_be_hex(test_vector.k);

            let k = generate_k(&msg_hash, &priv_key, Some(&seed));

            assert_eq!(
                k, expected_k,
                "k mismatch for test vector #{} (msg_hash = {})",
                index, test_vector.msg_hash
            );
        }
    }
}