// sp1_recursion_core/runtime/program.rs

1use crate::*;
2use p3_field::Field;
3use serde::{Deserialize, Serialize};
4use shape::RecursionShape;
5use sp1_stark::air::{MachineAir, MachineProgram};
6use sp1_stark::septic_digest::SepticDigest;
7use std::ops::Deref;
8
9pub use basic_block::BasicBlock;
10pub use raw::RawProgram;
11pub use seq_block::SeqBlock;
12
13/// A well-formed recursion program. See [`Self::new_unchecked`] for guaranteed (safety) invariants.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[repr(transparent)]
16pub struct RecursionProgram<F>(RootProgram<F>);
17
18impl<F> RecursionProgram<F> {
19    /// # Safety
20    /// The given program must be well formed. This is defined as the following:
21    /// - reads are performed after writes, according to a "happens-before" relation; and
22    /// - an address is written to at most once.
23    ///
24    /// The "happens-before" relation is defined as follows:
25    /// - It is a strict partial order, meaning it is transitive, irreflexive, and asymmetric.
26    /// - Instructions in a `BasicBlock` are linearly ordered.
27    /// - `SeqBlock`s in a `RawProgram` are linearly ordered, meaning:
28    ///     - Each `SeqBlock` has a set of initial instructions `I` and final instructions `O`.
29    ///     - For `SeqBlock::Basic`:
30    ///         - `I` is the singleton consisting of the first instruction in the enclosed `BasicBlock`.
31    ///         - `O` is the singleton consisting of the last instruction in the enclosed `BasicBlock`.
32    ///     - For `SeqBlock::Parallel`:
33    ///         - `I` is the set of initial instructions `I` in the first `SeqBlock` of the enclosed `RawProgram`.
34    ///         - `O` is the set of final instructions in the last `SeqBlock` of the enclosed `RawProgram`.
35    ///     - For consecutive `SeqBlock`s, each element of the first one's `O` happens before the second one's `I`.
36    pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
37        Self(program)
38    }
39
40    pub fn into_inner(self) -> RootProgram<F> {
41        self.0
42    }
43
44    pub fn shape_mut(&mut self) -> &mut Option<RecursionShape> {
45        &mut self.0.shape
46    }
47}
48
49impl<F> Default for RecursionProgram<F> {
50    fn default() -> Self {
51        // SAFETY: An empty program is always well formed.
52        unsafe { Self::new_unchecked(RootProgram::default()) }
53    }
54}
55
56impl<F> Deref for RecursionProgram<F> {
57    type Target = RootProgram<F>;
58
59    fn deref(&self) -> &Self::Target {
60        &self.0
61    }
62}
63
64impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
65    fn pc_start(&self) -> F {
66        F::zero()
67    }
68
69    fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
70        SepticDigest::<F>::zero()
71    }
72}
73
74impl<F: Field> RecursionProgram<F> {
75    #[inline]
76    pub fn fixed_log2_rows<A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
77        self.0
78            .shape
79            .as_ref()
80            .map(|shape| {
81                shape
82                    .inner
83                    .get(&air.name())
84                    .unwrap_or_else(|| panic!("Chip {} not found in specified shape", air.name()))
85            })
86            .copied()
87    }
88}
89
90#[cfg(any(test, feature = "program_validation"))]
91pub use validation::*;
92
/// Structural validation of recursion programs (compiled for tests or with the
/// `program_validation` feature).
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    use super::*;

    use std::{fmt::Debug, iter, mem};

    use p3_field::PrimeField32;
    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use smallvec::{smallvec, SmallVec};
    use thiserror::Error;

    /// A violation of the instruction-level well-formedness rules.
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        /// `instr` reads `addr`, but nothing ordered before `instr` writes to it.
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// An inconsistency between a program's summary metadata and its instructions.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        /// `total_memory` is smaller than one past the highest written address.
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any error reported by [`RecursionProgram::try_new_unmodified`].
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        // `thiserror` only generates `Display` (required by `std::error::Error`) when messages
        // are given; `transparent` forwards `Display` and `source` to the wrapped error.
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validates `program` without modifying its summary metadata.
        ///
        /// # Errors
        /// - [`ValidationError::Structure`] if an instruction reads an address that is not
        ///   written before it.
        /// - [`ValidationError::Summary`] if `program.total_memory` is too small to hold every
        ///   written address.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)
                .map_err(|e| Box::new(ValidationError::Structure(*e)))?;
            if let Some(required) = required_memory(&written_addrs) {
                let configured = program.total_memory;
                if required > configured {
                    return Err(Box::new(ValidationError::Summary(SummaryError::OutOfMemory {
                        configured,
                        required,
                    })));
                }
            }
            // SAFETY: Every read was checked above to follow a write in the happens-before order.
            // NOTE(review): `try_written_addrs` does not check that each address is written at
            // most once, so that part of the `new_unchecked` contract is assumed, not verified.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validates `program`, overwriting `total_memory` with the amount actually required
        /// (one past the highest written address, or 0 if nothing is written).
        ///
        /// # Errors
        /// Returns [`StructureError::ReadFromUninit`] if an instruction reads an address that is
        /// not written before it.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)?;
            program.total_memory = required_memory(&written_addrs).unwrap_or_default();
            // SAFETY: Every read was checked above to follow a write in the happens-before order.
            // NOTE(review): the "written at most once" invariant is not checked here.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// Memory size needed to hold every address in `written_addrs` (one past the highest
    /// address), or `None` if the set is empty.
    fn required_memory(written_addrs: &RangeSetBlaze<u32>) -> Option<usize> {
        written_addrs.last().map(|max_addr| max_addr as usize + 1)
    }

    /// Checks that every read in `program` targets an already-readable address, and returns the
    /// set of addresses that `program` writes.
    ///
    /// `readable_addrs` holds the write sets of code ordered before `program` (enclosing and
    /// earlier blocks); `program` may read those in addition to its own earlier writes.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]>,
        program: &RawProgram<Instruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.io_addrs();
                        // Inputs are checked before this instruction's outputs are recorded, so
                        // an instruction cannot satisfy its own reads.
                        for addr in inputs {
                            let addr_u32 = addr.0.as_canonical_u32();
                            let readable = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .any(|set| set.contains(addr_u32));
                            if !readable {
                                return Err(Box::new(StructureError::ReadFromUninit {
                                    addr,
                                    instr: instr.clone(),
                                }));
                            }
                        }
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    // Each branch may read what was written before this block, but not what its
                    // sibling branches write.
                    let par_written_addrs = programs
                        .iter()
                        .map(|subprogram| {
                            let sub_readable_addrs = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();
                            try_written_addrs(sub_readable_addrs, subprogram)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    written_addrs =
                        iter::once(mem::take(&mut written_addrs)).chain(par_written_addrs).union();
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Validates this program and fills in `total_memory`; equivalent to
        /// [`RecursionProgram::try_new`].
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Builds and validates a program made of a single basic block containing `instrs`.
    #[cfg(test)]
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        RootProgram {
            inner: RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] },
            total_memory: 0, // Filled in by `validate`.
            shape: None,
        }
        .validate()
    }
}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct RootProgram<F> {
214    pub inner: raw::RawProgram<Instruction<F>>,
215    pub total_memory: usize,
216    pub shape: Option<RecursionShape>,
217}
218
219// `Default` without bounds on the type parameter.
220impl<F> Default for RootProgram<F> {
221    fn default() -> Self {
222        Self {
223            inner: Default::default(),
224            total_memory: Default::default(),
225            shape: Default::default(),
226        }
227    }
228}
229
230pub mod raw {
231    use std::iter::Flatten;
232
233    use super::*;
234
235    #[derive(Debug, Clone, Serialize, Deserialize)]
236    pub struct RawProgram<T> {
237        pub seq_blocks: Vec<SeqBlock<T>>,
238    }
239
240    // `Default` without bounds on the type parameter.
241    impl<T> Default for RawProgram<T> {
242        fn default() -> Self {
243            Self { seq_blocks: Default::default() }
244        }
245    }
246
247    impl<T> RawProgram<T> {
248        pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
249            self.seq_blocks.iter().flatten()
250        }
251        pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
252            self.seq_blocks.iter_mut().flatten()
253        }
254    }
255
256    impl<T> IntoIterator for RawProgram<T> {
257        type Item = T;
258
259        type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
260
261        fn into_iter(self) -> Self::IntoIter {
262            self.seq_blocks.into_iter().flatten()
263        }
264    }
265
266    impl<'a, T> IntoIterator for &'a RawProgram<T> {
267        type Item = &'a T;
268
269        type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
270
271        fn into_iter(self) -> Self::IntoIter {
272            self.seq_blocks.iter().flatten()
273        }
274    }
275
276    impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
277        type Item = &'a mut T;
278
279        type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
280
281        fn into_iter(self) -> Self::IntoIter {
282            self.seq_blocks.iter_mut().flatten()
283        }
284    }
285}
286
287pub mod seq_block {
288    use std::iter::Flatten;
289
290    use super::*;
291
292    /// Segments that may be sequentially composed.
293    #[derive(Debug, Clone, Serialize, Deserialize)]
294    pub enum SeqBlock<T> {
295        /// One basic block.
296        Basic(BasicBlock<T>),
297        /// Many blocks to be run in parallel.
298        Parallel(Vec<RawProgram<T>>),
299    }
300
301    impl<T> SeqBlock<T> {
302        pub fn iter(&self) -> Iter<'_, T> {
303            self.into_iter()
304        }
305
306        pub fn iter_mut(&mut self) -> IterMut<'_, T> {
307            self.into_iter()
308        }
309    }
310
311    // Bunch of iterator boilerplate.
312    #[derive(Debug)]
313    pub enum Iter<'a, T> {
314        Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
315        Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
316    }
317
318    impl<'a, T> Iterator for Iter<'a, T> {
319        type Item = &'a T;
320
321        fn next(&mut self) -> Option<Self::Item> {
322            match self {
323                Iter::Basic(it) => it.next(),
324                Iter::Parallel(it) => it.next(),
325            }
326        }
327    }
328
329    impl<'a, T> IntoIterator for &'a SeqBlock<T> {
330        type Item = &'a T;
331
332        type IntoIter = Iter<'a, T>;
333
334        fn into_iter(self) -> Self::IntoIter {
335            match self {
336                SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
337                SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
338            }
339        }
340    }
341
342    #[derive(Debug)]
343    pub enum IterMut<'a, T> {
344        Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
345        Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
346    }
347
348    impl<'a, T> Iterator for IterMut<'a, T> {
349        type Item = &'a mut T;
350
351        fn next(&mut self) -> Option<Self::Item> {
352            match self {
353                IterMut::Basic(it) => it.next(),
354                IterMut::Parallel(it) => it.next(),
355            }
356        }
357    }
358
359    impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
360        type Item = &'a mut T;
361
362        type IntoIter = IterMut<'a, T>;
363
364        fn into_iter(self) -> Self::IntoIter {
365            match self {
366                SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
367                SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
368            }
369        }
370    }
371
372    #[derive(Debug, Clone)]
373    pub enum IntoIter<T> {
374        Basic(<Vec<T> as IntoIterator>::IntoIter),
375        Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
376    }
377
378    impl<T> Iterator for IntoIter<T> {
379        type Item = T;
380
381        fn next(&mut self) -> Option<Self::Item> {
382            match self {
383                IntoIter::Basic(it) => it.next(),
384                IntoIter::Parallel(it) => it.next(),
385            }
386        }
387    }
388
389    impl<T> IntoIterator for SeqBlock<T> {
390        type Item = T;
391
392        type IntoIter = IntoIter<T>;
393
394        fn into_iter(self) -> Self::IntoIter {
395            match self {
396                SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
397                SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
398            }
399        }
400    }
401}
402
403pub mod basic_block {
404    use super::*;
405
406    #[derive(Debug, Clone, Serialize, Deserialize)]
407    pub struct BasicBlock<T> {
408        pub instrs: Vec<T>,
409    }
410
411    // Less restrictive trait bounds.
412    impl<T> Default for BasicBlock<T> {
413        fn default() -> Self {
414            Self { instrs: Default::default() }
415        }
416    }
417}