Skip to main content

sp1_prover/
utils.rs

1use std::{
2    fs::{self, File},
3    io::Read,
4    iter::{once, Skip, Take},
5    sync::Arc,
6};
7
8use rand::{rngs::OsRng, RngCore};
9
10use itertools::Itertools;
11use slop_symmetric::CryptographicHasher;
12use sp1_core_executor::Program;
13use sp1_core_executor_runner::MinimalExecutorRunner;
14use sp1_core_machine::io::SP1Stdin;
15use sp1_primitives::{poseidon2_hasher, SP1Field};
16use sp1_recursion_circuit::machine::RootPublicValues;
17use sp1_recursion_executor::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH};
18
19use crate::SP1CoreProofData;
20
21/// Compute the digest of the public values.
22pub fn recursion_public_values_digest(
23    public_values: &RecursionPublicValues<SP1Field>,
24) -> [SP1Field; 8] {
25    let hasher = poseidon2_hasher();
26    hasher.hash_slice(&public_values.as_array()[0..NUM_PV_ELMS_TO_HASH])
27}
28
29pub fn root_public_values_digest(public_values: &RootPublicValues<SP1Field>) -> [SP1Field; 8] {
30    let hasher = poseidon2_hasher();
31    let input = (*public_values.sp1_vk_digest())
32        .into_iter()
33        .chain(
34            (*public_values.committed_value_digest()).into_iter().flat_map(|word| word.into_iter()),
35        )
36        .chain(once(*public_values.exit_code()))
37        .chain(*public_values.vk_root())
38        .chain(*public_values.proof_nonce())
39        .collect::<Vec<_>>();
40    hasher.hash_slice(&input)
41}
42
43pub fn is_root_public_values_valid(public_values: &RootPublicValues<SP1Field>) -> bool {
44    let expected_digest = root_public_values_digest(public_values);
45    for (value, expected) in public_values.digest().iter().copied().zip_eq(expected_digest) {
46        if value != expected {
47            return false;
48        }
49    }
50    true
51}
52
53/// Assert that the digest of the public values is correct.
54pub fn is_recursion_public_values_valid(public_values: &RecursionPublicValues<SP1Field>) -> bool {
55    let expected_digest = recursion_public_values_digest(public_values);
56    for (value, expected) in public_values.digest.iter().copied().zip_eq(expected_digest) {
57        if value != expected {
58            return false;
59        }
60    }
61    true
62}
63
64impl SP1CoreProofData {
65    pub fn save(&self, path: &str) -> Result<(), std::io::Error> {
66        let data = serde_json::to_string(self).unwrap();
67        fs::write(path, data).unwrap();
68        Ok(())
69    }
70}
71
72/// Get the number of cycles for a given program.
73pub fn get_cycles(elf: &[u8], stdin: &SP1Stdin) -> u64 {
74    let program = Program::from(elf).unwrap();
75    let mut executor = MinimalExecutorRunner::simple(Arc::new(program));
76    for buf in &stdin.buffer {
77        executor.with_input(buf);
78    }
79    while executor.execute_chunk().is_some() {}
80    executor.global_clk()
81}
82
/// Load an ELF file from a given path, returning its raw bytes.
///
/// # Errors
///
/// Propagates any I/O error from opening or reading the file.
pub fn load_elf(path: &str) -> Result<Vec<u8>, std::io::Error> {
    let mut file = File::open(path)?;
    let mut bytes = Vec::new();
    file.read_to_end(&mut bytes)?;
    Ok(bytes)
}
89
/// Flatten eight 4-element words into a single vector of 32 elements, preserving order.
pub fn words_to_bytes<T: Copy>(words: &[[T; 4]; 8]) -> Vec<T> {
    words.concat()
}
93
/// Convert 32 big-endian bytes back into eight `u32` words.
///
/// Each consecutive group of four bytes becomes one word, most significant byte first.
pub fn bytes_to_words_be(bytes: &[u8; 32]) -> [u32; 8] {
    let mut out = [0u32; 8];
    for (word, chunk) in out.iter_mut().zip(bytes.chunks_exact(4)) {
        *word = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
    }
    out
}
103
104pub trait MaybeTakeIterator<I: Iterator>: Iterator<Item = I::Item> {
105    fn maybe_skip(self, bound: Option<usize>) -> RangedIterator<Self>
106    where
107        Self: Sized,
108    {
109        match bound {
110            Some(bound) => RangedIterator::Skip(self.skip(bound)),
111            None => RangedIterator::Unbounded(self),
112        }
113    }
114
115    fn maybe_take(self, bound: Option<usize>) -> RangedIterator<Self>
116    where
117        Self: Sized,
118    {
119        match bound {
120            Some(bound) => RangedIterator::Take(self.take(bound)),
121            None => RangedIterator::Unbounded(self),
122        }
123    }
124}
125
126impl<I: Iterator> MaybeTakeIterator<I> for I {}
127
/// An iterator that is either unbounded or wrapped in a `skip`, `take`, or `skip`-then-`take`
/// adapter.
///
/// Every variant has a concrete type, so a `RangedIterator<I>` can be returned from code that
/// only sometimes applies a bound, without boxing.
pub enum RangedIterator<I> {
    /// No bound applied; yields everything from the inner iterator.
    Unbounded(I),
    /// The inner iterator wrapped in `Iterator::skip`.
    Skip(Skip<I>),
    /// The inner iterator wrapped in `Iterator::take`.
    Take(Take<I>),
    /// The inner iterator wrapped in `skip` followed by `take`.
    Range(Take<Skip<I>>),
}

impl<I: Iterator> Iterator for RangedIterator<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            RangedIterator::Unbounded(unbounded) => unbounded.next(),
            RangedIterator::Skip(skip) => skip.next(),
            RangedIterator::Take(take) => take.next(),
            RangedIterator::Range(range) => range.next(),
        }
    }

    // Forward the wrapped adapter's bounds. Without this, the default `(0, None)` is reported,
    // so `collect`/`extend` on a `RangedIterator` cannot preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            RangedIterator::Unbounded(unbounded) => unbounded.size_hint(),
            RangedIterator::Skip(skip) => skip.size_hint(),
            RangedIterator::Take(take) => take.size_hint(),
            RangedIterator::Range(range) => range.size_hint(),
        }
    }
}
147
148/// Generate a 128-bit nonce using OsRng.
149pub fn generate_nonce() -> [u32; 4] {
150    let mut bytes = [0u8; 16];
151    OsRng.fill_bytes(&mut bytes);
152
153    [
154        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
155        u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
156        u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]),
157        u32::from_be_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]),
158    ]
159}