Skip to main content

sp1_recursion_executor/
program.rs

1use crate::{analyzed::AnalyzedInstruction, shape::RecursionShape, *};
2use serde::{Deserialize, Serialize};
3use slop_algebra::Field;
4use sp1_hypercube::{air::MachineProgram, septic_digest::SepticDigest};
5use std::ops::{Deref, DerefMut};
6
7pub use basic_block::BasicBlock;
8pub use raw::RawProgram;
9pub use seq_block::SeqBlock;
10
11/// A well-formed recursion program. See [`Self::new_unchecked`] for guaranteed (safety) invariants.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[repr(transparent)]
14pub struct RecursionProgram<F>(RootProgram<F>);
15
16impl<F> RecursionProgram<F> {
17    /// # Safety
18    /// The given program must be well formed. This is defined as the following:
19    /// - reads are performed after writes, according to a "happens-before" relation; and
20    /// - an address is written to at most once.
21    ///
22    /// The "happens-before" relation is defined as follows:
23    /// - It is a strict partial order, meaning it is transitive, irreflexive, and asymmetric.
24    /// - Instructions in a `BasicBlock` are linearly ordered.
25    /// - `SeqBlock`s in a `RawProgram` are linearly ordered, meaning:
26    ///     - Each `SeqBlock` has a set of initial instructions `I` and final instructions `O`.
27    ///     - For `SeqBlock::Basic`:
28    ///         - `I` is the singleton consisting of the first instruction in the enclosed
29    ///           `BasicBlock`.
30    ///         - `O` is the singleton consisting of the last instruction in the enclosed
31    ///           `BasicBlock`.
32    ///     - For `SeqBlock::Parallel`:
33    ///         - `I` is the set of initial instructions `I` in the first `SeqBlock` of the enclosed
34    ///           `RawProgram`.
35    ///         - `O` is the set of final instructions in the last `SeqBlock` of the enclosed
36    ///           `RawProgram`.
37    ///     - For consecutive `SeqBlock`s, each element of the first one's `O` happens before the
38    ///       second one's `I`.
39    ///
40    /// - The last condition is the event count analysis is done correctly see [`crate::analyzed`].
41    pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
42        Self(program)
43    }
44
45    pub fn into_inner(self) -> RootProgram<F> {
46        self.0
47    }
48}
49
50impl<F> Default for RecursionProgram<F> {
51    fn default() -> Self {
52        // SAFETY: An empty program is always well formed.
53        unsafe { Self::new_unchecked(RootProgram::default()) }
54    }
55}
56
57impl<F> Deref for RecursionProgram<F> {
58    type Target = RootProgram<F>;
59
60    fn deref(&self) -> &Self::Target {
61        &self.0
62    }
63}
64
65impl<F> DerefMut for RecursionProgram<F> {
66    fn deref_mut(&mut self) -> &mut Self::Target {
67        &mut self.0
68    }
69}
70
71impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
72    fn pc_start(&self) -> [F; 3] {
73        [F::zero(), F::zero(), F::zero()]
74    }
75
76    fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
77        SepticDigest::<F>::zero()
78    }
79
80    fn enable_untrusted_programs(&self) -> F {
81        F::zero()
82    }
83}
84
#[cfg(any(test, feature = "program_validation"))]
pub use validation::*;

/// Checked construction of [`RecursionProgram`]s.
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    use super::*;

    use std::{fmt::Debug, iter, mem};

    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use slop_algebra::PrimeField32;
    use smallvec::{smallvec, SmallVec};
    use thiserror::Error;

    /// A violation of the read-after-write structure of a program.
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// A program's summary metadata disagrees with its instructions.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any error reported by [`RecursionProgram::try_new_unmodified`].
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        // `transparent` forwards `Display` and `source` to the wrapped error. Without an
        // `#[error(..)]` attribute thiserror generates no `Display` impl, which `Error`
        // requires as a supertrait.
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validates `program` without modifying its summary metadata.
        ///
        /// # Errors
        /// - [`ValidationError::Structure`] if an instruction reads an address that was not
        ///   written before it.
        /// - [`ValidationError::Summary`] if `program.total_memory` does not exceed every
        ///   written address.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)
                .map_err(|e| Box::new(ValidationError::Structure(*e)))?;
            let required = required_memory(&written_addrs);
            let configured = program.total_memory;
            if required > configured {
                return Err(Box::new(ValidationError::Summary(SummaryError::OutOfMemory {
                    configured,
                    required,
                })));
            }
            // SAFETY: every read was checked to follow a write to the same address, and
            // `total_memory` covers every written address.
            // NOTE(review): the write-at-most-once and event-count conditions required by
            // `new_unchecked` are not checked here; confirm they are guaranteed elsewhere.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validates `program`, overwriting `total_memory` with the amount it requires.
        ///
        /// # Errors
        /// Returns a [`StructureError`] if an instruction reads an address that was not
        /// written before it.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)?;
            program.total_memory = required_memory(&written_addrs);
            // SAFETY: every read was checked to follow a write to the same address, and
            // `total_memory` was just set to cover every written address.
            // NOTE(review): the write-at-most-once and event-count conditions required by
            // `new_unchecked` are not checked here; confirm they are guaranteed elsewhere.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// Returns the `total_memory` needed to cover `written_addrs`: one past the highest
    /// written address, or zero when nothing is written.
    fn required_memory(written_addrs: &RangeSetBlaze<u32>) -> usize {
        written_addrs.last().map_or(0, |max_addr| max_addr as usize + 1)
    }

    /// Returns the set of addresses written by `program`. It checks that every read targets
    /// an address written earlier in `program` or contained in one of `readable_addrs`.
    ///
    /// `readable_addrs` holds the addresses written by enclosing programs before `program`
    /// starts. Sub-programs of a `SeqBlock::Parallel` are checked independently: each may
    /// read what was written before the parallel block, but not what its siblings write.
    /// Writing the same address more than once is not detected.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]>,
        program: &RawProgram<AnalyzedInstruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.inner.io_addrs();
                        for addr in inputs {
                            let raw_addr = addr.0.as_canonical_u32();
                            let initialized = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .any(|set| set.contains(raw_addr));
                            if !initialized {
                                return Err(Box::new(StructureError::ReadFromUninit {
                                    addr,
                                    instr: instr.inner.clone(),
                                }));
                            }
                        }
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    let par_written_addrs = programs
                        .iter()
                        .map(|subprogram| {
                            // Siblings only see what was written before this parallel block.
                            let sub_readable_addrs = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();

                            try_written_addrs(sub_readable_addrs, subprogram)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    written_addrs =
                        iter::once(mem::take(&mut written_addrs)).chain(par_written_addrs).union();
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Validates this program, recomputing `total_memory`
        /// (see [`RecursionProgram::try_new`]).
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Analyzes and validates a program consisting of a single basic block of `instrs`.
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        let (analyzed, counts) =
            RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] }.analyze();

        // `total_memory` is a placeholder; `validate` recomputes it.
        RootProgram { inner: analyzed, total_memory: 0, shape: None, event_counts: counts }
            .validate()
    }
}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct RootProgram<F> {
207    pub inner: raw::RawProgram<AnalyzedInstruction<F>>,
208    pub total_memory: usize,
209    pub shape: Option<RecursionShape<F>>,
210    pub event_counts: RecursionAirEventCount,
211}
212
213// `Default` without bounds on the type parameter.
214impl<F> Default for RootProgram<F> {
215    fn default() -> Self {
216        Self {
217            inner: Default::default(),
218            total_memory: Default::default(),
219            shape: None,
220            event_counts: Default::default(),
221        }
222    }
223}
224
225pub mod raw {
226    use std::iter::Flatten;
227
228    use super::*;
229
230    #[derive(Debug, Clone, Serialize, Deserialize)]
231    pub struct RawProgram<T> {
232        pub seq_blocks: Vec<SeqBlock<T>>,
233    }
234
235    // `Default` without bounds on the type parameter.
236    impl<T> Default for RawProgram<T> {
237        fn default() -> Self {
238            Self { seq_blocks: Default::default() }
239        }
240    }
241
242    impl<T> RawProgram<T> {
243        pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
244            self.seq_blocks.iter().flatten()
245        }
246        pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
247            self.seq_blocks.iter_mut().flatten()
248        }
249    }
250
251    impl<T> IntoIterator for RawProgram<T> {
252        type Item = T;
253
254        type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
255
256        fn into_iter(self) -> Self::IntoIter {
257            self.seq_blocks.into_iter().flatten()
258        }
259    }
260
261    impl<'a, T> IntoIterator for &'a RawProgram<T> {
262        type Item = &'a T;
263
264        type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
265
266        fn into_iter(self) -> Self::IntoIter {
267            self.seq_blocks.iter().flatten()
268        }
269    }
270
271    impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
272        type Item = &'a mut T;
273
274        type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
275
276        fn into_iter(self) -> Self::IntoIter {
277            self.seq_blocks.iter_mut().flatten()
278        }
279    }
280}
281
282pub mod seq_block {
283    use std::iter::Flatten;
284
285    use super::*;
286
287    /// Segments that may be sequentially composed.
288    #[derive(Debug, Clone, Serialize, Deserialize)]
289    pub enum SeqBlock<T> {
290        /// One basic block.
291        Basic(BasicBlock<T>),
292        /// Many blocks to be run in parallel.
293        Parallel(Vec<RawProgram<T>>),
294    }
295
296    impl<T> SeqBlock<T> {
297        pub fn iter(&self) -> Iter<'_, T> {
298            self.into_iter()
299        }
300
301        pub fn iter_mut(&mut self) -> IterMut<'_, T> {
302            self.into_iter()
303        }
304    }
305
306    // Bunch of iterator boilerplate.
307    #[derive(Debug)]
308    pub enum Iter<'a, T> {
309        Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
310        Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
311    }
312
313    impl<'a, T> Iterator for Iter<'a, T> {
314        type Item = &'a T;
315
316        fn next(&mut self) -> Option<Self::Item> {
317            match self {
318                Iter::Basic(it) => it.next(),
319                Iter::Parallel(it) => it.next(),
320            }
321        }
322    }
323
324    impl<'a, T> IntoIterator for &'a SeqBlock<T> {
325        type Item = &'a T;
326
327        type IntoIter = Iter<'a, T>;
328
329        fn into_iter(self) -> Self::IntoIter {
330            match self {
331                SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
332                SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
333            }
334        }
335    }
336
337    #[derive(Debug)]
338    pub enum IterMut<'a, T> {
339        Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
340        Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
341    }
342
343    impl<'a, T> Iterator for IterMut<'a, T> {
344        type Item = &'a mut T;
345
346        fn next(&mut self) -> Option<Self::Item> {
347            match self {
348                IterMut::Basic(it) => it.next(),
349                IterMut::Parallel(it) => it.next(),
350            }
351        }
352    }
353
354    impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
355        type Item = &'a mut T;
356
357        type IntoIter = IterMut<'a, T>;
358
359        fn into_iter(self) -> Self::IntoIter {
360            match self {
361                SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
362                SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
363            }
364        }
365    }
366
367    #[derive(Debug, Clone)]
368    pub enum IntoIter<T> {
369        Basic(<Vec<T> as IntoIterator>::IntoIter),
370        Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
371    }
372
373    impl<T> Iterator for IntoIter<T> {
374        type Item = T;
375
376        fn next(&mut self) -> Option<Self::Item> {
377            match self {
378                IntoIter::Basic(it) => it.next(),
379                IntoIter::Parallel(it) => it.next(),
380            }
381        }
382    }
383
384    impl<T> IntoIterator for SeqBlock<T> {
385        type Item = T;
386
387        type IntoIter = IntoIter<T>;
388
389        fn into_iter(self) -> Self::IntoIter {
390            match self {
391                SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
392                SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
393            }
394        }
395    }
396}
397
398pub mod basic_block {
399    use super::*;
400
401    #[derive(Debug, Clone, Serialize, Deserialize)]
402    pub struct BasicBlock<T> {
403        pub instrs: Vec<T>,
404    }
405
406    // Less restrictive trait bounds.
407    impl<T> Default for BasicBlock<T> {
408        fn default() -> Self {
409            Self { instrs: Default::default() }
410        }
411    }
412}