Skip to main content

sp1_recursion_circuit/basefold/
tcs.rs

1use crate::{basefold::merkle_tree::verify, hash::FieldHasherVariable, CircuitConfig};
2use itertools::Itertools;
3use slop_algebra::AbstractField;
4use slop_tensor::Tensor;
5use sp1_primitives::SP1Field;
6use sp1_recursion_compiler::ir::{Builder, Felt, IrIter};
7use std::marker::PhantomData;
8
9/// An opening of a tensor commitment scheme.
10pub struct RecursiveTensorCsOpening<CommitmentVariable> {
11    /// The claimed values of the opening.
12    pub values: Tensor<Felt<SP1Field>>,
13    /// The proof of the opening.
14    pub proof: Tensor<CommitmentVariable>,
15
16    pub merkle_root: CommitmentVariable,
17
18    pub log_height: usize,
19    pub width: usize,
20}
21
22#[derive(Debug, Copy, PartialEq, Eq)]
23pub struct RecursiveMerkleTreeTcs<C, M>(pub PhantomData<(C, M)>);
24
25impl<C, M> Clone for RecursiveMerkleTreeTcs<C, M> {
26    fn clone(&self) -> Self {
27        Self(PhantomData)
28    }
29}
30
31impl<C, M> RecursiveMerkleTreeTcs<C, M>
32where
33    C: CircuitConfig,
34    M: FieldHasherVariable<C>,
35{
36    pub fn verify_tensor_openings(
37        builder: &mut Builder<C>,
38        commit: &M::DigestVariable,
39        indices: &[Vec<C::Bit>],
40        opening: &RecursiveTensorCsOpening<M::DigestVariable>,
41    ) {
42        let chunk_size = indices.len().div_ceil(8);
43
44        let log_height = builder.constant(SP1Field::from_canonical_usize(opening.log_height));
45        let width = builder.constant(SP1Field::from_canonical_usize(opening.width));
46        let hash = M::hash(builder, &[log_height, width]);
47        let expected_commit = M::compress(builder, [opening.merkle_root, hash]);
48        M::assert_digest_eq(builder, expected_commit, *commit);
49
50        indices
51            .iter()
52            .zip_eq(opening.proof.split())
53            .map(|(x, y)| (x.clone(), y.as_slice().to_vec()))
54            .collect::<Vec<_>>()
55            .chunks(chunk_size)
56            .enumerate()
57            .ir_par_map_collect::<Vec<_>, _, _>(builder, |builder, (i, chunk)| {
58                for (j, (index, path)) in chunk.iter().enumerate() {
59                    let claimed_values_slices =
60                        opening.values.get(i * chunk_size + j).unwrap().as_slice().to_vec();
61
62                    let path = path.as_slice().to_vec();
63                    let digest = M::hash(builder, &claimed_values_slices);
64
65                    verify::<C, M>(builder, path, index.to_vec(), digest, opening.merkle_root);
66                }
67            });
68    }
69}
70
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};
    use slop_commit::Message;
    use slop_merkle_tree::{ComputeTcsOpenings, MerkleTreeOpeningAndProof, TensorCsProver};
    use sp1_hypercube::inner_perm;
    use sp1_recursion_compiler::circuit::AsmConfig;
    use std::sync::Arc;

    use slop_algebra::extension::BinomialExtensionField;
    use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};

    use crate::witness::Witnessable;

    use super::*;
    use itertools::Itertools;
    use slop_tensor::Tensor;
    use sp1_hypercube::prover::SP1MerkleTreeProver;
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler};
    use sp1_recursion_executor::Executor;

    use sp1_primitives::SP1Field;
    type F = SP1Field;
    type EF = BinomialExtensionField<SP1Field, 4>;

    /// Commits to random tensors, opens them at random rows, and runs the recursive
    /// verifier over the resulting opening in the executor.
    ///
    /// When `corrupt_low_bit` is set, bit 0 of every query index is flipped before it is
    /// handed to the circuit, so the Merkle paths no longer match the claimed positions.
    ///
    /// Returns the executor's outcome, with any error rendered through its `Debug` impl.
    fn run_merkle_verification(corrupt_low_bit: bool) -> Result<(), String> {
        let mut rng = thread_rng();

        let height = rng.gen_range(500..2000);
        let width = rng.gen_range(15..30);
        let num_tensors = rng.gen_range(5..15);

        let num_indices = rng.gen_range(2..10);

        let tensors = (0..num_tensors)
            .map(|_| Tensor::<SP1Field>::rand(&mut rng, [height, width]))
            .collect::<Message<_>>();

        let prover = SP1MerkleTreeProver::default();
        let (root, data) = prover.commit_tensors(tensors.clone()).unwrap();

        let indices = (0..num_indices).map(|_| rng.gen_range(0..height)).collect_vec();
        let proof = prover.prove_openings_at_indices(data, &indices).unwrap();
        let openings = prover.compute_openings_at_indices(tensors, &indices);
        let opening: MerkleTreeOpeningAndProof<SP1GlobalContext> =
            MerkleTreeOpeningAndProof { values: openings, proof };

        let bit_len = height.next_power_of_two().ilog2();

        let mut builder = AsmBuilder::default();
        let mut witness_stream = Vec::new();

        let mut index_bits = Vec::new();
        for index in indices {
            // Bit `i` of the index is at position `i` (least-significant bit first).
            let bits = (0..bit_len)
                .map(|i| {
                    let bit = (index >> i) & 1 == 1;
                    if corrupt_low_bit && i == 0 {
                        !bit
                    } else {
                        bit
                    }
                })
                .collect_vec();
            Witnessable::<AsmConfig>::write(&bits, &mut witness_stream);
            let bits = bits.read(&mut builder);
            index_bits.push(bits);
        }

        Witnessable::<AsmConfig>::write(&root, &mut witness_stream);
        let root = root.read(&mut builder);
        Witnessable::<AsmConfig>::write(&opening, &mut witness_stream);
        let opening = opening.read(&mut builder);

        RecursiveMerkleTreeTcs::<AsmConfig, SP1GlobalContext>::verify_tensor_openings(
            &mut builder,
            &root,
            &index_bits,
            &opening,
        );

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let mut executor = Executor::<F, EF, SP1DiffusionMatrix>::new(program, inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().map(|_| ()).map_err(|err| format!("{err:?}"))
    }

    #[test]
    fn test_merkle_proof() {
        run_merkle_verification(false).unwrap();
    }

    #[test]
    fn test_invalid_merkle_proof() {
        run_merkle_verification(true).expect_err("merkle proof should not verify");
    }
}