// sp1_recursion_core/runtime/program.rs
use crate::*;
use p3_field::Field;
use serde::{Deserialize, Serialize};
use shape::RecursionShape;
use sp1_stark::{
    air::{MachineAir, MachineProgram},
    septic_digest::SepticDigest,
};
use std::ops::Deref;

pub use basic_block::BasicBlock;
pub use raw::RawProgram;
pub use seq_block::SeqBlock;

15#[derive(Debug, Clone, Serialize, Deserialize)]
17#[repr(transparent)]
18pub struct RecursionProgram<F>(RootProgram<F>);
19
20impl<F> RecursionProgram<F> {
21 pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
44 Self(program)
45 }
46
47 pub fn into_inner(self) -> RootProgram<F> {
48 self.0
49 }
50
51 pub fn shape_mut(&mut self) -> &mut Option<RecursionShape> {
52 &mut self.0.shape
53 }
54}
55
56impl<F> Default for RecursionProgram<F> {
57 fn default() -> Self {
58 unsafe { Self::new_unchecked(RootProgram::default()) }
60 }
61}
62
63impl<F> Deref for RecursionProgram<F> {
64 type Target = RootProgram<F>;
65
66 fn deref(&self) -> &Self::Target {
67 &self.0
68 }
69}
70
71impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
72 fn pc_start(&self) -> F {
73 F::zero()
74 }
75
76 fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
77 SepticDigest::<F>::zero()
78 }
79}
80
81impl<F: Field> RecursionProgram<F> {
82 #[inline]
83 pub fn fixed_log2_rows<A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
84 self.0
85 .shape
86 .as_ref()
87 .map(|shape| {
88 shape
89 .inner
90 .get(&air.name())
91 .unwrap_or_else(|| panic!("Chip {} not found in specified shape", air.name()))
92 })
93 .copied()
94 }
95}
96
#[cfg(any(test, feature = "program_validation"))]
pub use validation::*;

/// Checked construction of [`RecursionProgram`]s (only under `test` or the
/// `program_validation` feature).
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    use super::*;

    use std::{fmt::Debug, iter};

    use p3_field::PrimeField32;
    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use smallvec::SmallVec;
    use thiserror::Error;

    /// A defect in the instruction structure of a program.
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        /// An instruction reads an address that nothing visible to it has written.
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// A mismatch between a program's configured summary values and what it requires.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        /// `total_memory` is not greater than the highest written address.
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any error produced by [`RecursionProgram::try_new_unmodified`].
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        // thiserror only generates `Display` when `#[error(...)]` attributes are present;
        // `transparent` forwards both `Display` and `source` to the wrapped error.
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validates `program` and wraps it without modifying it.
        ///
        /// # Errors
        ///
        /// - [`ValidationError::Structure`] if some instruction reads an uninitialized
        ///   address.
        /// - [`ValidationError::Summary`] if `program.total_memory` is not greater than the
        ///   highest written address.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(&[], &program.inner)
                .map_err(|e| ValidationError::from(*e))?;
            if let Some(required) = required_memory(&written_addrs) {
                let configured = program.total_memory;
                if required > configured {
                    return Err(Box::new(
                        SummaryError::OutOfMemory { configured, required }.into(),
                    ));
                }
            }
            // SAFETY: the structural check passed and `total_memory` covers every written
            // address.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validates `program` and wraps it, overwriting `total_memory` with exactly what
        /// the program needs (one past the highest written address, or 0 if it writes
        /// nothing).
        ///
        /// # Errors
        ///
        /// Returns [`StructureError::ReadFromUninit`] if some instruction reads an
        /// uninitialized address.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(&[], &program.inner)?;
            program.total_memory = required_memory(&written_addrs).unwrap_or_default();
            // SAFETY: the structural check passed and `total_memory` was just set to cover
            // every written address.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// One past the highest address in `written_addrs`, or `None` if the set is empty.
    fn required_memory(written_addrs: &RangeSetBlaze<u32>) -> Option<usize> {
        written_addrs.last().map(|max_addr| max_addr as usize + 1)
    }

    /// Walks `program` and returns every address it writes.
    ///
    /// `readable_addrs` are the address sets already readable from enclosing scopes. Within
    /// a basic block, each instruction may read addresses written by earlier instructions
    /// (at this level) or found in `readable_addrs`. Each branch of a parallel block may
    /// read what was written before that block, but not what its sibling branches write.
    ///
    /// # Errors
    ///
    /// Returns [`StructureError::ReadFromUninit`] for the first offending read.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: &[&RangeSetBlaze<u32>],
        program: &RawProgram<Instruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.io_addrs();
                        for addr in inputs {
                            let addr_u32 = addr.0.as_canonical_u32();
                            let initialized = written_addrs.contains(addr_u32)
                                || readable_addrs.iter().any(|s| s.contains(addr_u32));
                            if !initialized {
                                return Err(Box::new(StructureError::ReadFromUninit {
                                    addr,
                                    instr: instr.clone(),
                                }));
                            }
                        }
                        // Outputs become readable only after all inputs were checked, so an
                        // instruction cannot satisfy its own read.
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    // Scoped so the borrow of `written_addrs` held by `sub_readable_addrs`
                    // ends before `written_addrs` is reassigned below.
                    let par_written_addrs = {
                        // Identical for every branch, so built once.
                        let sub_readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]> =
                            iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();
                        programs
                            .iter()
                            .map(|subprogram| try_written_addrs(&sub_readable_addrs, subprogram))
                            .collect::<Result<Vec<_>, _>>()?
                    };
                    // Union over references, matching the imported `MultiwayRangeSetBlazeRef`.
                    let merged =
                        iter::once(&written_addrs).chain(par_written_addrs.iter()).union();
                    written_addrs = merged;
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Validates this program via [`RecursionProgram::try_new`].
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Test helper: builds and validates a program made of a single basic block.
    #[cfg(test)]
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        RootProgram {
            inner: RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] },
            total_memory: 0,
            shape: None,
        }
        .validate()
    }
}

219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct RootProgram<F> {
221 pub inner: raw::RawProgram<Instruction<F>>,
222 pub total_memory: usize,
223 pub shape: Option<RecursionShape>,
224}
225
226impl<F> Default for RootProgram<F> {
228 fn default() -> Self {
229 Self {
230 inner: Default::default(),
231 total_memory: Default::default(),
232 shape: Default::default(),
233 }
234 }
235}
236
237pub mod raw {
238 use std::iter::Flatten;
239
240 use super::*;
241
242 #[derive(Debug, Clone, Serialize, Deserialize)]
243 pub struct RawProgram<T> {
244 pub seq_blocks: Vec<SeqBlock<T>>,
245 }
246
247 impl<T> Default for RawProgram<T> {
249 fn default() -> Self {
250 Self { seq_blocks: Default::default() }
251 }
252 }
253
254 impl<T> RawProgram<T> {
255 pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
256 self.seq_blocks.iter().flatten()
257 }
258 pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
259 self.seq_blocks.iter_mut().flatten()
260 }
261 }
262
263 impl<T> IntoIterator for RawProgram<T> {
264 type Item = T;
265
266 type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
267
268 fn into_iter(self) -> Self::IntoIter {
269 self.seq_blocks.into_iter().flatten()
270 }
271 }
272
273 impl<'a, T> IntoIterator for &'a RawProgram<T> {
274 type Item = &'a T;
275
276 type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
277
278 fn into_iter(self) -> Self::IntoIter {
279 self.seq_blocks.iter().flatten()
280 }
281 }
282
283 impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
284 type Item = &'a mut T;
285
286 type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
287
288 fn into_iter(self) -> Self::IntoIter {
289 self.seq_blocks.iter_mut().flatten()
290 }
291 }
292}
293
294pub mod seq_block {
295 use std::iter::Flatten;
296
297 use super::*;
298
299 #[derive(Debug, Clone, Serialize, Deserialize)]
301 pub enum SeqBlock<T> {
302 Basic(BasicBlock<T>),
304 Parallel(Vec<RawProgram<T>>),
306 }
307
308 impl<T> SeqBlock<T> {
309 pub fn iter(&self) -> Iter<'_, T> {
310 self.into_iter()
311 }
312
313 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
314 self.into_iter()
315 }
316 }
317
318 #[derive(Debug)]
320 pub enum Iter<'a, T> {
321 Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
322 Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
323 }
324
325 impl<'a, T> Iterator for Iter<'a, T> {
326 type Item = &'a T;
327
328 fn next(&mut self) -> Option<Self::Item> {
329 match self {
330 Iter::Basic(it) => it.next(),
331 Iter::Parallel(it) => it.next(),
332 }
333 }
334 }
335
336 impl<'a, T> IntoIterator for &'a SeqBlock<T> {
337 type Item = &'a T;
338
339 type IntoIter = Iter<'a, T>;
340
341 fn into_iter(self) -> Self::IntoIter {
342 match self {
343 SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
344 SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
345 }
346 }
347 }
348
349 #[derive(Debug)]
350 pub enum IterMut<'a, T> {
351 Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
352 Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
353 }
354
355 impl<'a, T> Iterator for IterMut<'a, T> {
356 type Item = &'a mut T;
357
358 fn next(&mut self) -> Option<Self::Item> {
359 match self {
360 IterMut::Basic(it) => it.next(),
361 IterMut::Parallel(it) => it.next(),
362 }
363 }
364 }
365
366 impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
367 type Item = &'a mut T;
368
369 type IntoIter = IterMut<'a, T>;
370
371 fn into_iter(self) -> Self::IntoIter {
372 match self {
373 SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
374 SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
375 }
376 }
377 }
378
379 #[derive(Debug, Clone)]
380 pub enum IntoIter<T> {
381 Basic(<Vec<T> as IntoIterator>::IntoIter),
382 Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
383 }
384
385 impl<T> Iterator for IntoIter<T> {
386 type Item = T;
387
388 fn next(&mut self) -> Option<Self::Item> {
389 match self {
390 IntoIter::Basic(it) => it.next(),
391 IntoIter::Parallel(it) => it.next(),
392 }
393 }
394 }
395
396 impl<T> IntoIterator for SeqBlock<T> {
397 type Item = T;
398
399 type IntoIter = IntoIter<T>;
400
401 fn into_iter(self) -> Self::IntoIter {
402 match self {
403 SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
404 SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
405 }
406 }
407 }
408}
409
410pub mod basic_block {
411 use super::*;
412
413 #[derive(Debug, Clone, Serialize, Deserialize)]
414 pub struct BasicBlock<T> {
415 pub instrs: Vec<T>,
416 }
417
418 impl<T> Default for BasicBlock<T> {
420 fn default() -> Self {
421 Self { instrs: Default::default() }
422 }
423 }
424}