sp1_recursion_core/runtime/
program.rs

1use crate::*;
2use p3_field::Field;
3use serde::{Deserialize, Serialize};
4use shape::RecursionShape;
5use sp1_stark::{
6    air::{MachineAir, MachineProgram},
7    septic_digest::SepticDigest,
8};
9use std::ops::Deref;
10
11pub use basic_block::BasicBlock;
12pub use raw::RawProgram;
13pub use seq_block::SeqBlock;
14
15/// A well-formed recursion program. See [`Self::new_unchecked`] for guaranteed (safety) invariants.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[repr(transparent)]
18pub struct RecursionProgram<F>(RootProgram<F>);
19
20impl<F> RecursionProgram<F> {
21    /// # Safety
22    /// The given program must be well formed. This is defined as the following:
23    /// - reads are performed after writes, according to a "happens-before" relation; and
24    /// - an address is written to at most once.
25    ///
26    /// The "happens-before" relation is defined as follows:
27    /// - It is a strict partial order, meaning it is transitive, irreflexive, and asymmetric.
28    /// - Instructions in a `BasicBlock` are linearly ordered.
29    /// - `SeqBlock`s in a `RawProgram` are linearly ordered, meaning:
30    ///     - Each `SeqBlock` has a set of initial instructions `I` and final instructions `O`.
31    ///     - For `SeqBlock::Basic`:
32    ///         - `I` is the singleton consisting of the first instruction in the enclosed
33    ///           `BasicBlock`.
34    ///         - `O` is the singleton consisting of the last instruction in the enclosed
35    ///           `BasicBlock`.
36    ///     - For `SeqBlock::Parallel`:
37    ///         - `I` is the set of initial instructions `I` in the first `SeqBlock` of the enclosed
38    ///           `RawProgram`.
39    ///         - `O` is the set of final instructions in the last `SeqBlock` of the enclosed
40    ///           `RawProgram`.
41    ///     - For consecutive `SeqBlock`s, each element of the first one's `O` happens before the
42    ///       second one's `I`.
43    pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
44        Self(program)
45    }
46
47    pub fn into_inner(self) -> RootProgram<F> {
48        self.0
49    }
50
51    pub fn shape_mut(&mut self) -> &mut Option<RecursionShape> {
52        &mut self.0.shape
53    }
54}
55
56impl<F> Default for RecursionProgram<F> {
57    fn default() -> Self {
58        // SAFETY: An empty program is always well formed.
59        unsafe { Self::new_unchecked(RootProgram::default()) }
60    }
61}
62
63impl<F> Deref for RecursionProgram<F> {
64    type Target = RootProgram<F>;
65
66    fn deref(&self) -> &Self::Target {
67        &self.0
68    }
69}
70
71impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
72    fn pc_start(&self) -> F {
73        F::zero()
74    }
75
76    fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
77        SepticDigest::<F>::zero()
78    }
79}
80
81impl<F: Field> RecursionProgram<F> {
82    #[inline]
83    pub fn fixed_log2_rows<A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
84        self.0
85            .shape
86            .as_ref()
87            .map(|shape| {
88                shape
89                    .inner
90                    .get(&air.name())
91                    .unwrap_or_else(|| panic!("Chip {} not found in specified shape", air.name()))
92            })
93            .copied()
94    }
95}
96
97#[cfg(any(test, feature = "program_validation"))]
98pub use validation::*;
99
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    use super::*;

    use std::{fmt::Debug, iter, mem};

    use p3_field::PrimeField32;
    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use smallvec::{smallvec, SmallVec};
    use thiserror::Error;

    /// A violation of the read-after-write structure required of a [`RecursionProgram`].
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        /// `instr` reads `addr`, but nothing that happens before `instr` writes `addr`.
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// An inconsistency between a program's summary metadata and its instructions.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        /// `total_memory` does not cover the highest address the program writes.
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any failure reported by [`RecursionProgram::try_new_unmodified`].
    // thiserror only generates a `Display` impl when the variants carry `#[error(...)]`
    // attributes. `transparent` forwards both `Display` and `source` to the wrapped error, so
    // this type is printable and usable as a `std::error::Error`.
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    /// Memory size needed to hold every address in `written_addrs`: one past the highest
    /// written address, or `None` if nothing is written.
    fn required_memory(written_addrs: &RangeSetBlaze<u32>) -> Option<usize> {
        written_addrs.last().map(|max_addr| max_addr as usize + 1)
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validate the program without modifying its summary metadata.
        ///
        /// # Errors
        /// Returns [`ValidationError::Structure`] if some instruction reads an address before
        /// it has been written, and [`ValidationError::Summary`] if `program.total_memory` is
        /// too small to hold every written address.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)
                .map_err(|e| Box::new(ValidationError::Structure(*e)))?;
            if let Some(required) = required_memory(&written_addrs) {
                let configured = program.total_memory;
                if required > configured {
                    return Err(Box::new(ValidationError::Summary(SummaryError::OutOfMemory {
                        configured,
                        required,
                    })));
                }
            }
            // SAFETY: `try_written_addrs` verified that every read happens after a write, and
            // `total_memory` was checked to cover every written address.
            // NOTE(review): the "written to at most once" condition listed on `new_unchecked`
            // is not checked by this validator; confirm it is upheld by the program builder.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validate the program, modifying summary metadata if necessary.
        ///
        /// `total_memory` is overwritten with the smallest size that holds every written
        /// address (zero if nothing is written).
        ///
        /// # Errors
        /// Returns a [`StructureError`] if some instruction reads an address before it has
        /// been written.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)?;
            program.total_memory = required_memory(&written_addrs).unwrap_or_default();
            // SAFETY: `try_written_addrs` verified that every read happens after a write, and
            // `total_memory` was just set to cover every written address.
            // NOTE(review): the "written to at most once" condition listed on `new_unchecked`
            // is not checked by this validator; confirm it is upheld by the program builder.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// Checks that every instruction in `program` reads only initialized addresses, and
    /// returns the set of addresses that `program` writes.
    ///
    /// An address counts as initialized if it was written earlier in `program` or appears in
    /// one of `readable_addrs`, the addresses the enclosing programs had written before
    /// entering `program`.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]>,
        program: &RawProgram<Instruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.io_addrs();
                        for addr in inputs {
                            let addr_u32 = addr.0.as_canonical_u32();
                            let initialized = written_addrs.contains(addr_u32)
                                || readable_addrs.iter().any(|s| s.contains(addr_u32));
                            if !initialized {
                                return Err(Box::new(StructureError::ReadFromUninit {
                                    addr,
                                    instr: instr.clone(),
                                }));
                            }
                        }
                        // Outputs are recorded only after the inputs are checked, so an
                        // instruction cannot satisfy its own reads.
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    // Each branch may read whatever was written before this block (at any
                    // nesting level), but not what its sibling branches write.
                    let par_written_addrs = programs
                        .iter()
                        .map(|subprogram| {
                            let sub_readable_addrs = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();
                            try_written_addrs(sub_readable_addrs, subprogram)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    // After the parallel block, everything any branch wrote is readable.
                    written_addrs =
                        iter::once(mem::take(&mut written_addrs)).chain(par_written_addrs).union();
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Validates this program, recomputing `total_memory`; see [`RecursionProgram::try_new`].
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Builds and validates a program consisting of a single basic block.
    #[cfg(test)]
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        RootProgram {
            inner: RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] },
            total_memory: 0, // Will be filled in.
            shape: None,
        }
        .validate()
    }
}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct RootProgram<F> {
221    pub inner: raw::RawProgram<Instruction<F>>,
222    pub total_memory: usize,
223    pub shape: Option<RecursionShape>,
224}
225
226// `Default` without bounds on the type parameter.
227impl<F> Default for RootProgram<F> {
228    fn default() -> Self {
229        Self {
230            inner: Default::default(),
231            total_memory: Default::default(),
232            shape: Default::default(),
233        }
234    }
235}
236
237pub mod raw {
238    use std::iter::Flatten;
239
240    use super::*;
241
242    #[derive(Debug, Clone, Serialize, Deserialize)]
243    pub struct RawProgram<T> {
244        pub seq_blocks: Vec<SeqBlock<T>>,
245    }
246
247    // `Default` without bounds on the type parameter.
248    impl<T> Default for RawProgram<T> {
249        fn default() -> Self {
250            Self { seq_blocks: Default::default() }
251        }
252    }
253
254    impl<T> RawProgram<T> {
255        pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
256            self.seq_blocks.iter().flatten()
257        }
258        pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
259            self.seq_blocks.iter_mut().flatten()
260        }
261    }
262
263    impl<T> IntoIterator for RawProgram<T> {
264        type Item = T;
265
266        type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
267
268        fn into_iter(self) -> Self::IntoIter {
269            self.seq_blocks.into_iter().flatten()
270        }
271    }
272
273    impl<'a, T> IntoIterator for &'a RawProgram<T> {
274        type Item = &'a T;
275
276        type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
277
278        fn into_iter(self) -> Self::IntoIter {
279            self.seq_blocks.iter().flatten()
280        }
281    }
282
283    impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
284        type Item = &'a mut T;
285
286        type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
287
288        fn into_iter(self) -> Self::IntoIter {
289            self.seq_blocks.iter_mut().flatten()
290        }
291    }
292}
293
294pub mod seq_block {
295    use std::iter::Flatten;
296
297    use super::*;
298
299    /// Segments that may be sequentially composed.
300    #[derive(Debug, Clone, Serialize, Deserialize)]
301    pub enum SeqBlock<T> {
302        /// One basic block.
303        Basic(BasicBlock<T>),
304        /// Many blocks to be run in parallel.
305        Parallel(Vec<RawProgram<T>>),
306    }
307
308    impl<T> SeqBlock<T> {
309        pub fn iter(&self) -> Iter<'_, T> {
310            self.into_iter()
311        }
312
313        pub fn iter_mut(&mut self) -> IterMut<'_, T> {
314            self.into_iter()
315        }
316    }
317
318    // Bunch of iterator boilerplate.
319    #[derive(Debug)]
320    pub enum Iter<'a, T> {
321        Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
322        Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
323    }
324
325    impl<'a, T> Iterator for Iter<'a, T> {
326        type Item = &'a T;
327
328        fn next(&mut self) -> Option<Self::Item> {
329            match self {
330                Iter::Basic(it) => it.next(),
331                Iter::Parallel(it) => it.next(),
332            }
333        }
334    }
335
336    impl<'a, T> IntoIterator for &'a SeqBlock<T> {
337        type Item = &'a T;
338
339        type IntoIter = Iter<'a, T>;
340
341        fn into_iter(self) -> Self::IntoIter {
342            match self {
343                SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
344                SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
345            }
346        }
347    }
348
349    #[derive(Debug)]
350    pub enum IterMut<'a, T> {
351        Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
352        Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
353    }
354
355    impl<'a, T> Iterator for IterMut<'a, T> {
356        type Item = &'a mut T;
357
358        fn next(&mut self) -> Option<Self::Item> {
359            match self {
360                IterMut::Basic(it) => it.next(),
361                IterMut::Parallel(it) => it.next(),
362            }
363        }
364    }
365
366    impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
367        type Item = &'a mut T;
368
369        type IntoIter = IterMut<'a, T>;
370
371        fn into_iter(self) -> Self::IntoIter {
372            match self {
373                SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
374                SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
375            }
376        }
377    }
378
379    #[derive(Debug, Clone)]
380    pub enum IntoIter<T> {
381        Basic(<Vec<T> as IntoIterator>::IntoIter),
382        Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
383    }
384
385    impl<T> Iterator for IntoIter<T> {
386        type Item = T;
387
388        fn next(&mut self) -> Option<Self::Item> {
389            match self {
390                IntoIter::Basic(it) => it.next(),
391                IntoIter::Parallel(it) => it.next(),
392            }
393        }
394    }
395
396    impl<T> IntoIterator for SeqBlock<T> {
397        type Item = T;
398
399        type IntoIter = IntoIter<T>;
400
401        fn into_iter(self) -> Self::IntoIter {
402            match self {
403                SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
404                SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
405            }
406        }
407    }
408}
409
410pub mod basic_block {
411    use super::*;
412
413    #[derive(Debug, Clone, Serialize, Deserialize)]
414    pub struct BasicBlock<T> {
415        pub instrs: Vec<T>,
416    }
417
418    // Less restrictive trait bounds.
419    impl<T> Default for BasicBlock<T> {
420        fn default() -> Self {
421            Self { instrs: Default::default() }
422        }
423    }
424}