sp1_recursion_circuit/basefold/
tcs.rs1use crate::{basefold::merkle_tree::verify, hash::FieldHasherVariable, CircuitConfig};
2use itertools::Itertools;
3use slop_algebra::AbstractField;
4use slop_tensor::Tensor;
5use sp1_primitives::SP1Field;
6use sp1_recursion_compiler::ir::{Builder, Felt, IrIter};
7use std::marker::PhantomData;
8
9pub struct RecursiveTensorCsOpening<CommitmentVariable> {
11 pub values: Tensor<Felt<SP1Field>>,
13 pub proof: Tensor<CommitmentVariable>,
15
16 pub merkle_root: CommitmentVariable,
17
18 pub log_height: usize,
19 pub width: usize,
20}
21
/// Zero-sized marker selecting the recursive Merkle-tree tensor-commitment verifier for circuit
/// configuration `C` and field hasher `M`.
///
/// The standard traits are implemented by hand rather than derived: `#[derive]` would add
/// `C: Trait, M: Trait` bounds even though only a `PhantomData` is stored. The hand-written
/// impls make the marker `Clone`/`Copy`/`Debug`/`Eq` for every `C` and `M`.
pub struct RecursiveMerkleTreeTcs<C, M>(pub PhantomData<(C, M)>);

impl<C, M> Clone for RecursiveMerkleTreeTcs<C, M> {
    fn clone(&self) -> Self {
        // The type is `Copy`, so cloning is a bitwise copy of a zero-sized value.
        *self
    }
}

impl<C, M> Copy for RecursiveMerkleTreeTcs<C, M> {}

impl<C, M> std::fmt::Debug for RecursiveMerkleTreeTcs<C, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output shape as the derived impl: `RecursiveMerkleTreeTcs(PhantomData<…>)`.
        f.debug_tuple("RecursiveMerkleTreeTcs").field(&self.0).finish()
    }
}

impl<C, M> PartialEq for RecursiveMerkleTreeTcs<C, M> {
    fn eq(&self, _other: &Self) -> bool {
        // There is no data to compare: all values of this marker type are equal.
        true
    }
}

impl<C, M> Eq for RecursiveMerkleTreeTcs<C, M> {}
30
31impl<C, M> RecursiveMerkleTreeTcs<C, M>
32where
33 C: CircuitConfig,
34 M: FieldHasherVariable<C>,
35{
36 pub fn verify_tensor_openings(
37 builder: &mut Builder<C>,
38 commit: &M::DigestVariable,
39 indices: &[Vec<C::Bit>],
40 opening: &RecursiveTensorCsOpening<M::DigestVariable>,
41 ) {
42 let chunk_size = indices.len().div_ceil(8);
43
44 let log_height = builder.constant(SP1Field::from_canonical_usize(opening.log_height));
45 let width = builder.constant(SP1Field::from_canonical_usize(opening.width));
46 let hash = M::hash(builder, &[log_height, width]);
47 let expected_commit = M::compress(builder, [opening.merkle_root, hash]);
48 M::assert_digest_eq(builder, expected_commit, *commit);
49
50 indices
51 .iter()
52 .zip_eq(opening.proof.split())
53 .map(|(x, y)| (x.clone(), y.as_slice().to_vec()))
54 .collect::<Vec<_>>()
55 .chunks(chunk_size)
56 .enumerate()
57 .ir_par_map_collect::<Vec<_>, _, _>(builder, |builder, (i, chunk)| {
58 for (j, (index, path)) in chunk.iter().enumerate() {
59 let claimed_values_slices =
60 opening.values.get(i * chunk_size + j).unwrap().as_slice().to_vec();
61
62 let path = path.as_slice().to_vec();
63 let digest = M::hash(builder, &claimed_values_slices);
64
65 verify::<C, M>(builder, path, index.to_vec(), digest, opening.merkle_root);
66 }
67 });
68 }
69}
70
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};
    use slop_commit::Message;
    use slop_merkle_tree::{ComputeTcsOpenings, MerkleTreeOpeningAndProof, TensorCsProver};
    use sp1_hypercube::inner_perm;
    use sp1_recursion_compiler::circuit::AsmConfig;
    use std::sync::Arc;

    use slop_algebra::extension::BinomialExtensionField;
    use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};

    use crate::witness::Witnessable;

    use super::*;
    use itertools::Itertools;
    use slop_tensor::Tensor;
    use sp1_hypercube::prover::SP1MerkleTreeProver;
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler};
    use sp1_recursion_executor::Executor;

    use sp1_primitives::SP1Field;
    type F = SP1Field;
    type EF = BinomialExtensionField<SP1Field, 4>;

    /// Commits to random tensors, opens them at random rows, and runs
    /// `RecursiveMerkleTreeTcs::verify_tensor_openings` over the opening in the executor.
    ///
    /// With `flip_low_bit == false` the circuit receives the true query indices and execution
    /// must succeed. With `flip_low_bit == true` bit 0 of every index is inverted, so each
    /// path is checked at the wrong leaf position and execution must fail.
    fn run_merkle_verification(flip_low_bit: bool) {
        let mut rng = thread_rng();

        let height = rng.gen_range(500..2000);
        let width = rng.gen_range(15..30);
        let num_tensors = rng.gen_range(5..15);
        let num_indices = rng.gen_range(2..10);

        let tensors = (0..num_tensors)
            .map(|_| Tensor::<SP1Field>::rand(&mut rng, [height, width]))
            .collect::<Message<_>>();

        let prover = SP1MerkleTreeProver::default();
        let (root, data) = prover.commit_tensors(tensors.clone()).unwrap();

        let indices = (0..num_indices).map(|_| rng.gen_range(0..height)).collect_vec();
        let proof = prover.prove_openings_at_indices(data, &indices).unwrap();
        let openings = prover.compute_openings_at_indices(tensors, &indices);
        let opening: MerkleTreeOpeningAndProof<SP1GlobalContext> =
            MerkleTreeOpeningAndProof { values: openings, proof };

        // Number of bits needed to address a leaf of the power-of-two padded tree.
        let bit_len = height.next_power_of_two().ilog2();

        let mut builder = AsmBuilder::default();
        let mut witness_stream = Vec::new();

        let mut index_bits = Vec::new();
        for index in indices {
            // Least-significant bit first; optionally invert bit 0 to corrupt the index.
            let bits = (0..bit_len)
                .map(|i| ((index >> i) & 1 == 1) ^ (flip_low_bit && i == 0))
                .collect_vec();
            Witnessable::<AsmConfig>::write(&bits, &mut witness_stream);
            let bits = bits.read(&mut builder);
            index_bits.push(bits);
        }

        Witnessable::<AsmConfig>::write(&root, &mut witness_stream);
        let root = root.read(&mut builder);
        Witnessable::<AsmConfig>::write(&opening, &mut witness_stream);
        let opening = opening.read(&mut builder);

        RecursiveMerkleTreeTcs::<AsmConfig, SP1GlobalContext>::verify_tensor_openings(
            &mut builder,
            &root,
            &index_bits,
            &opening,
        );

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(program.clone(), inner_perm());
        executor.witness_stream = witness_stream.into();

        let result = executor.run();
        if flip_low_bit {
            result.expect_err("merkle proof should not verify");
        } else {
            result.unwrap();
        }
    }

    #[test]
    fn test_merkle_proof() {
        run_merkle_verification(false);
    }

    #[test]
    fn test_invalid_merkle_proof() {
        run_merkle_verification(true);
    }
}