sp1_recursion_machine/
test.rs1use std::sync::Arc;
2
3use slop_algebra::extension::BinomialExtensionField;
4use slop_basefold::FriConfig;
5use sp1_hypercube::{
6 inner_perm, prover::simple_prover, Machine, MachineProof, MachineVerifier,
7 MachineVerifierConfigError, SP1InnerPcs, SP1PcsProofInner, ShardVerifier,
8};
9use sp1_primitives::{
10 fri_params::{unique_decoding_queries, SP1_PROOF_OF_WORK_BITS},
11 SP1DiffusionMatrix, SP1Field, SP1GlobalContext,
12};
13use sp1_recursion_executor::{
14 linear_program, Block, ExecutionRecord, Executor, Instruction, RecursionProgram, D,
15};
16use tracing::Instrument;
17
18use crate::machine::RecursionAir;
19
20pub async fn run_recursion_test_machines(
22 program: RecursionProgram<SP1Field>,
23 witness: Vec<Block<SP1Field>>,
24) {
25 type A = RecursionAir<SP1Field, 3, 2>;
26
27 let mut executor =
28 Executor::<SP1Field, BinomialExtensionField<SP1Field, D>, SP1DiffusionMatrix>::new(
29 Arc::new(program.clone()),
30 inner_perm(),
31 );
32 executor.witness_stream = witness.into();
33 executor.run().unwrap();
34
35 let machine = A::compress_machine();
37 run_test_recursion(vec![executor.record.clone()], machine, program.clone()).await.unwrap();
38}
39
40pub async fn test_recursion_linear_program(instrs: Vec<Instruction<SP1Field>>) {
43 run_recursion_test_machines(linear_program(instrs).unwrap(), Vec::new()).await;
44}
45
46pub async fn run_test_recursion<const DEGREE: usize, const VAR_EVENTS_PER_ROW: usize>(
47 records: Vec<ExecutionRecord<SP1Field>>,
48 machine: Machine<SP1Field, RecursionAir<SP1Field, DEGREE, VAR_EVENTS_PER_ROW>>,
49 program: RecursionProgram<SP1Field>,
50) -> Result<
51 MachineProof<SP1GlobalContext, SP1PcsProofInner>,
52 MachineVerifierConfigError<SP1GlobalContext, SP1InnerPcs>,
53> {
54 let log_blowup = 1;
55 let num_queries = unique_decoding_queries(log_blowup);
56 let log_stacking_height = 22;
57 let max_log_row_count = 21;
58 let verifier = ShardVerifier::from_basefold_parameters(
59 FriConfig::new(log_blowup, num_queries, SP1_PROOF_OF_WORK_BITS),
60 log_stacking_height,
61 max_log_row_count,
62 machine,
63 );
64 let prover = simple_prover(verifier.clone());
65
66 let (pk, vk) = prover
67 .setup(Arc::new(program))
68 .instrument(tracing::debug_span!("setup").or_current())
69 .await;
70
71 let pk = unsafe { pk.into_inner() };
72 let mut shard_proofs = Vec::with_capacity(records.len());
73 for record in records {
74 let proof = prover.prove_shard(pk.clone(), record).await;
75 shard_proofs.push(proof);
76 }
77
78 assert!(shard_proofs.len() == 1);
79
80 let proof = MachineProof { shard_proofs };
81
82 let machine_verifier = MachineVerifier::new(verifier);
83 tracing::debug_span!("verify the proof").in_scope(|| machine_verifier.verify(&vk, &proof))?;
84 Ok(proof)
85}