sp1_recursion_core/runtime/
program.rs1use crate::*;
2use p3_field::Field;
3use serde::{Deserialize, Serialize};
4use shape::RecursionShape;
5use sp1_stark::air::{MachineAir, MachineProgram};
6use sp1_stark::septic_digest::SepticDigest;
7use std::ops::Deref;
8
9pub use basic_block::BasicBlock;
10pub use raw::RawProgram;
11pub use seq_block::SeqBlock;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[repr(transparent)]
16pub struct RecursionProgram<F>(RootProgram<F>);
17
18impl<F> RecursionProgram<F> {
19 pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
37 Self(program)
38 }
39
40 pub fn into_inner(self) -> RootProgram<F> {
41 self.0
42 }
43
44 pub fn shape_mut(&mut self) -> &mut Option<RecursionShape> {
45 &mut self.0.shape
46 }
47}
48
49impl<F> Default for RecursionProgram<F> {
50 fn default() -> Self {
51 unsafe { Self::new_unchecked(RootProgram::default()) }
53 }
54}
55
56impl<F> Deref for RecursionProgram<F> {
57 type Target = RootProgram<F>;
58
59 fn deref(&self) -> &Self::Target {
60 &self.0
61 }
62}
63
64impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
65 fn pc_start(&self) -> F {
66 F::zero()
67 }
68
69 fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
70 SepticDigest::<F>::zero()
71 }
72}
73
74impl<F: Field> RecursionProgram<F> {
75 #[inline]
76 pub fn fixed_log2_rows<A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
77 self.0
78 .shape
79 .as_ref()
80 .map(|shape| {
81 shape
82 .inner
83 .get(&air.name())
84 .unwrap_or_else(|| panic!("Chip {} not found in specified shape", air.name()))
85 })
86 .copied()
87 }
88}
89
90#[cfg(any(test, feature = "program_validation"))]
91pub use validation::*;
92
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    use super::*;

    use std::{fmt::Debug, iter, mem};

    use p3_field::PrimeField32;
    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use smallvec::{smallvec, SmallVec};
    use thiserror::Error;

    /// Errors concerning the instruction-level structure of a program.
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        /// An instruction reads an address that nothing visible to it has written yet.
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// Errors concerning the summary fields stored next to the instructions.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        /// `total_memory` is smaller than the highest written address plus one.
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any error that [`RecursionProgram::try_new_unmodified`] can report.
    ///
    /// Both variants are transparent: `Display` and `source` are forwarded to the wrapped
    /// error. Without a display attribute on the variants no `Display` impl is derived at
    /// all, which leaves the type unusable as a `std::error::Error`.
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        /// The instructions themselves are malformed.
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        /// The instructions are fine but a summary field does not match them.
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validates `program` and wraps it without modifying any of its fields.
        ///
        /// # Errors
        ///
        /// - [`ValidationError::Structure`] if an instruction reads an address that has not
        ///   been written before it.
        /// - [`ValidationError::Summary`] if `program.total_memory` is smaller than the
        ///   highest written address plus one.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)
                .map_err(|e| ValidationError::from(*e))?;
            // A program that writes nothing needs no memory, so any `total_memory` is fine.
            if let Some(required) = written_addrs.last().map(|x| x as usize + 1) {
                let configured = program.total_memory;
                if required > configured {
                    let err = SummaryError::OutOfMemory { configured, required };
                    return Err(Box::new(ValidationError::from(err)));
                }
            }
            // SAFETY: `program` has just passed the structure and memory checks above.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validates the structure of `program` and wraps it, overwriting `total_memory`.
        ///
        /// `total_memory` is set to the highest written address plus one, or to zero when the
        /// program writes nothing. Whatever value it held before is discarded.
        ///
        /// # Errors
        ///
        /// Returns [`StructureError::ReadFromUninit`] if an instruction reads an address that
        /// has not been written before it.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)?;
            program.total_memory = written_addrs.last().map(|x| x as usize + 1).unwrap_or_default();
            // SAFETY: the structure check passed and `total_memory` was just set to exactly
            // what the written addresses require.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// Walks `program` in order and returns the set of addresses it writes.
    ///
    /// `readable_addrs` holds the address sets written by enclosing programs before this one
    /// starts; reads may be satisfied from those sets or from addresses written earlier in
    /// `program` itself.
    ///
    /// # Errors
    ///
    /// Returns [`StructureError::ReadFromUninit`] for the first input address that is in none
    /// of those sets.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]>,
        program: &RawProgram<Instruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.io_addrs();
                        // Every input must already be present in some visible set.
                        inputs.into_iter().try_for_each(|i| {
                            let i_u32 = i.0.as_canonical_u32();
                            iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .any(|s| s.contains(i_u32))
                                .then_some(())
                                .ok_or_else(|| {
                                    Box::new(StructureError::ReadFromUninit {
                                        addr: i,
                                        instr: instr.clone(),
                                    })
                                })
                        })?;
                        // Outputs are recorded only after the inputs were checked, so an
                        // instruction cannot satisfy its own read.
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    // Each subprogram sees what was written before the parallel block, but
                    // not what its siblings write.
                    let par_written_addrs = programs
                        .iter()
                        .map(|subprogram| {
                            let sub_readable_addrs = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();

                            try_written_addrs(sub_readable_addrs, subprogram)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    // After the block, the writes of all subprograms become visible.
                    written_addrs =
                        iter::once(mem::take(&mut written_addrs)).chain(par_written_addrs).union();
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Shorthand for [`RecursionProgram::try_new`]; note that it overwrites `total_memory`.
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Test helper: builds and validates a program made of a single basic block of `instrs`,
    /// with no shape.
    #[cfg(test)]
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        RootProgram {
            inner: RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] },
            // Placeholder: `validate` replaces this with the computed requirement.
            total_memory: 0,
            shape: None,
        }
        .validate()
    }
}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct RootProgram<F> {
214 pub inner: raw::RawProgram<Instruction<F>>,
215 pub total_memory: usize,
216 pub shape: Option<RecursionShape>,
217}
218
219impl<F> Default for RootProgram<F> {
221 fn default() -> Self {
222 Self {
223 inner: Default::default(),
224 total_memory: Default::default(),
225 shape: Default::default(),
226 }
227 }
228}
229
230pub mod raw {
231 use std::iter::Flatten;
232
233 use super::*;
234
235 #[derive(Debug, Clone, Serialize, Deserialize)]
236 pub struct RawProgram<T> {
237 pub seq_blocks: Vec<SeqBlock<T>>,
238 }
239
240 impl<T> Default for RawProgram<T> {
242 fn default() -> Self {
243 Self { seq_blocks: Default::default() }
244 }
245 }
246
247 impl<T> RawProgram<T> {
248 pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
249 self.seq_blocks.iter().flatten()
250 }
251 pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
252 self.seq_blocks.iter_mut().flatten()
253 }
254 }
255
256 impl<T> IntoIterator for RawProgram<T> {
257 type Item = T;
258
259 type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
260
261 fn into_iter(self) -> Self::IntoIter {
262 self.seq_blocks.into_iter().flatten()
263 }
264 }
265
266 impl<'a, T> IntoIterator for &'a RawProgram<T> {
267 type Item = &'a T;
268
269 type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
270
271 fn into_iter(self) -> Self::IntoIter {
272 self.seq_blocks.iter().flatten()
273 }
274 }
275
276 impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
277 type Item = &'a mut T;
278
279 type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
280
281 fn into_iter(self) -> Self::IntoIter {
282 self.seq_blocks.iter_mut().flatten()
283 }
284 }
285}
286
287pub mod seq_block {
288 use std::iter::Flatten;
289
290 use super::*;
291
292 #[derive(Debug, Clone, Serialize, Deserialize)]
294 pub enum SeqBlock<T> {
295 Basic(BasicBlock<T>),
297 Parallel(Vec<RawProgram<T>>),
299 }
300
301 impl<T> SeqBlock<T> {
302 pub fn iter(&self) -> Iter<'_, T> {
303 self.into_iter()
304 }
305
306 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
307 self.into_iter()
308 }
309 }
310
311 #[derive(Debug)]
313 pub enum Iter<'a, T> {
314 Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
315 Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
316 }
317
318 impl<'a, T> Iterator for Iter<'a, T> {
319 type Item = &'a T;
320
321 fn next(&mut self) -> Option<Self::Item> {
322 match self {
323 Iter::Basic(it) => it.next(),
324 Iter::Parallel(it) => it.next(),
325 }
326 }
327 }
328
329 impl<'a, T> IntoIterator for &'a SeqBlock<T> {
330 type Item = &'a T;
331
332 type IntoIter = Iter<'a, T>;
333
334 fn into_iter(self) -> Self::IntoIter {
335 match self {
336 SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
337 SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
338 }
339 }
340 }
341
342 #[derive(Debug)]
343 pub enum IterMut<'a, T> {
344 Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
345 Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
346 }
347
348 impl<'a, T> Iterator for IterMut<'a, T> {
349 type Item = &'a mut T;
350
351 fn next(&mut self) -> Option<Self::Item> {
352 match self {
353 IterMut::Basic(it) => it.next(),
354 IterMut::Parallel(it) => it.next(),
355 }
356 }
357 }
358
359 impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
360 type Item = &'a mut T;
361
362 type IntoIter = IterMut<'a, T>;
363
364 fn into_iter(self) -> Self::IntoIter {
365 match self {
366 SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
367 SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
368 }
369 }
370 }
371
372 #[derive(Debug, Clone)]
373 pub enum IntoIter<T> {
374 Basic(<Vec<T> as IntoIterator>::IntoIter),
375 Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
376 }
377
378 impl<T> Iterator for IntoIter<T> {
379 type Item = T;
380
381 fn next(&mut self) -> Option<Self::Item> {
382 match self {
383 IntoIter::Basic(it) => it.next(),
384 IntoIter::Parallel(it) => it.next(),
385 }
386 }
387 }
388
389 impl<T> IntoIterator for SeqBlock<T> {
390 type Item = T;
391
392 type IntoIter = IntoIter<T>;
393
394 fn into_iter(self) -> Self::IntoIter {
395 match self {
396 SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
397 SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
398 }
399 }
400 }
401}
402
403pub mod basic_block {
404 use super::*;
405
406 #[derive(Debug, Clone, Serialize, Deserialize)]
407 pub struct BasicBlock<T> {
408 pub instrs: Vec<T>,
409 }
410
411 impl<T> Default for BasicBlock<T> {
413 fn default() -> Self {
414 Self { instrs: Default::default() }
415 }
416 }
417}