1use crate::{analyzed::AnalyzedInstruction, shape::RecursionShape, *};
2use serde::{Deserialize, Serialize};
3use slop_algebra::Field;
4use sp1_hypercube::{air::MachineProgram, septic_digest::SepticDigest};
5use std::ops::{Deref, DerefMut};
6
7pub use basic_block::BasicBlock;
8pub use raw::RawProgram;
9pub use seq_block::SeqBlock;
10
/// Newtype wrapper around a [`RootProgram`].
///
/// Values are built either through the `unsafe` [`RecursionProgram::new_unchecked`] or through
/// the validating constructors that are compiled under `cfg(test)` or the `program_validation`
/// feature. `#[repr(transparent)]` gives the wrapper the same layout as the wrapped
/// [`RootProgram`].
///
/// NOTE(review): `Deserialize` is derived, so a deserialized value does not pass through the
/// validating constructors — confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[repr(transparent)]
pub struct RecursionProgram<F>(RootProgram<F>);
15
16impl<F> RecursionProgram<F> {
17 pub unsafe fn new_unchecked(program: RootProgram<F>) -> Self {
42 Self(program)
43 }
44
45 pub fn into_inner(self) -> RootProgram<F> {
46 self.0
47 }
48}
49
50impl<F> Default for RecursionProgram<F> {
51 fn default() -> Self {
52 unsafe { Self::new_unchecked(RootProgram::default()) }
54 }
55}
56
57impl<F> Deref for RecursionProgram<F> {
58 type Target = RootProgram<F>;
59
60 fn deref(&self) -> &Self::Target {
61 &self.0
62 }
63}
64
65impl<F> DerefMut for RecursionProgram<F> {
66 fn deref_mut(&mut self) -> &mut Self::Target {
67 &mut self.0
68 }
69}
70
71impl<F: Field> MachineProgram<F> for RecursionProgram<F> {
72 fn pc_start(&self) -> [F; 3] {
73 [F::zero(), F::zero(), F::zero()]
74 }
75
76 fn initial_global_cumulative_sum(&self) -> SepticDigest<F> {
77 SepticDigest::<F>::zero()
78 }
79
80 fn enable_untrusted_programs(&self) -> F {
81 F::zero()
82 }
83}
84
85#[cfg(any(test, feature = "program_validation"))]
86pub use validation::*;
87
#[cfg(any(test, feature = "program_validation"))]
mod validation {
    //! Checked constructors for [`RecursionProgram`] and the errors they report.

    use super::*;

    use std::{fmt::Debug, iter, mem};

    use range_set_blaze::{MultiwayRangeSetBlazeRef, RangeSetBlaze};
    use slop_algebra::PrimeField32;
    use smallvec::{smallvec, SmallVec};
    use thiserror::Error;

    /// Errors about the structure of a program's reads and writes.
    #[derive(Error, Debug)]
    pub enum StructureError<F: Debug> {
        /// An instruction reads an address that no earlier visible instruction has written.
        #[error("tried to read from uninitialized address {addr:?}. instruction: {instr:?}")]
        ReadFromUninit { addr: Address<F>, instr: Instruction<F> },
    }

    /// Errors about the summary fields stored next to a program.
    #[derive(Error, Debug)]
    pub enum SummaryError {
        /// `total_memory` is smaller than the highest written address plus one.
        #[error("`total_memory` is insufficient. configured: {configured}. required: {required}")]
        OutOfMemory { configured: usize, required: usize },
    }

    /// Any error reported by [`RecursionProgram::try_new_unmodified`].
    ///
    /// Both variants are `#[error(transparent)]`: `Display` and `source` are forwarded to the
    /// wrapped error. Without a display attribute on the variants no `Display` impl is
    /// generated, and the type would not implement `std::error::Error` at all.
    #[derive(Error, Debug)]
    pub enum ValidationError<F: Debug> {
        #[error(transparent)]
        Structure(#[from] StructureError<F>),
        #[error(transparent)]
        Summary(#[from] SummaryError),
    }

    /// Memory size needed to cover every address in `written_addrs`: the highest address plus
    /// one, or `None` when the set is empty.
    fn required_memory(written_addrs: &RangeSetBlaze<u32>) -> Option<usize> {
        // `u32 -> usize` is lossless wherever `usize` is at least 32 bits wide.
        written_addrs.last().map(|highest| highest as usize + 1)
    }

    impl<F: PrimeField32> RecursionProgram<F> {
        /// Validates `program` and wraps it without modifying any of its fields.
        ///
        /// # Errors
        ///
        /// - [`ValidationError::Structure`] if an instruction reads an address that has not
        ///   been written before it.
        /// - [`ValidationError::Summary`] if `program.total_memory` is smaller than the highest
        ///   written address plus one.
        pub fn try_new_unmodified(
            program: RootProgram<F>,
        ) -> Result<Self, Box<ValidationError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)
                .map_err(|e| Box::new(ValidationError::from(*e)))?;
            if let Some(required) = required_memory(&written_addrs) {
                let configured = program.total_memory;
                if required > configured {
                    return Err(Box::new(ValidationError::from(SummaryError::OutOfMemory {
                        configured,
                        required,
                    })));
                }
            }
            // SAFETY: `try_written_addrs` succeeded, so every read is preceded by a write, and
            // `total_memory` was just checked to cover the highest written address.
            Ok(unsafe { Self::new_unchecked(program) })
        }

        /// Validates `program`, overwrites its `total_memory` with the highest written address
        /// plus one (zero when nothing is written), and wraps it.
        ///
        /// # Errors
        ///
        /// Returns [`StructureError::ReadFromUninit`] if an instruction reads an address that
        /// has not been written before it.
        pub fn try_new(mut program: RootProgram<F>) -> Result<Self, Box<StructureError<F>>> {
            let written_addrs = try_written_addrs(smallvec![], &program.inner)?;
            program.total_memory = required_memory(&written_addrs).unwrap_or_default();
            // SAFETY: `try_written_addrs` succeeded, so every read is preceded by a write, and
            // `total_memory` was just set to cover the highest written address.
            Ok(unsafe { Self::new_unchecked(program) })
        }
    }

    /// Walks `program` in order and returns the set of addresses it writes.
    ///
    /// `readable_addrs` holds the address sets written by enclosing programs before this one
    /// starts; reads may be served from any of them or from the addresses written so far by
    /// `program` itself.
    ///
    /// For a `SeqBlock::Parallel`, each child program is checked recursively with the addresses
    /// written so far (plus `readable_addrs`) as its readable sets. A child therefore cannot
    /// read what a sibling writes. After the block, the writes of all children are merged into
    /// the running set.
    ///
    /// # Errors
    ///
    /// Returns [`StructureError::ReadFromUninit`] for the first input address that is in none
    /// of the readable sets.
    fn try_written_addrs<F: PrimeField32>(
        readable_addrs: SmallVec<[&RangeSetBlaze<u32>; 3]>,
        program: &RawProgram<AnalyzedInstruction<F>>,
    ) -> Result<RangeSetBlaze<u32>, Box<StructureError<F>>> {
        let mut written_addrs = RangeSetBlaze::<u32>::new();
        for block in &program.seq_blocks {
            match block {
                SeqBlock::Basic(basic_block) => {
                    for instr in &basic_block.instrs {
                        let (inputs, outputs) = instr.inner.io_addrs();
                        for addr in inputs {
                            let raw_addr = addr.0.as_canonical_u32();
                            // Check this program's own writes first, then the inherited sets.
                            let initialized = written_addrs.contains(raw_addr)
                                || readable_addrs.iter().any(|set| set.contains(raw_addr));
                            if !initialized {
                                return Err(Box::new(StructureError::ReadFromUninit {
                                    addr,
                                    instr: instr.inner.clone(),
                                }));
                            }
                        }
                        // Outputs become readable only after all inputs were checked.
                        written_addrs.extend(outputs.iter().map(|o| o.0.as_canonical_u32()));
                    }
                }
                SeqBlock::Parallel(programs) => {
                    let par_written_addrs = programs
                        .iter()
                        .map(|subprogram| {
                            let sub_readable_addrs = iter::once(&written_addrs)
                                .chain(readable_addrs.iter().copied())
                                .collect();

                            try_written_addrs(sub_readable_addrs, subprogram)
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    // `mem::take` moves the running set into the union instead of cloning it.
                    written_addrs =
                        iter::once(mem::take(&mut written_addrs)).chain(par_written_addrs).union();
                }
            }
        }
        Ok(written_addrs)
    }

    impl<F: PrimeField32> RootProgram<F> {
        /// Validates this program via [`RecursionProgram::try_new`], which also overwrites
        /// `total_memory`.
        ///
        /// # Errors
        ///
        /// Returns [`StructureError::ReadFromUninit`] if an instruction reads an address that
        /// has not been written before it.
        pub fn validate(self) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
            RecursionProgram::try_new(self)
        }
    }

    /// Builds a validated program consisting of a single basic block holding `instrs`.
    ///
    /// The instructions are analyzed with `RawProgram::analyze`; the resulting counts are
    /// stored as `event_counts`, `shape` is `None`, and `total_memory` is filled in by
    /// [`RootProgram::validate`].
    ///
    /// # Errors
    ///
    /// Returns [`StructureError::ReadFromUninit`] if an instruction reads an address that
    /// has not been written before it.
    pub fn linear_program<F: PrimeField32>(
        instrs: Vec<Instruction<F>>,
    ) -> Result<RecursionProgram<F>, Box<StructureError<F>>> {
        let (analyzed, counts) =
            RawProgram { seq_blocks: vec![SeqBlock::Basic(BasicBlock { instrs })] }.analyze();

        RootProgram { inner: analyzed, total_memory: 0, shape: None, event_counts: counts }
            .validate()
    }
}
204
/// A recursion program together with summary data stored alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RootProgram<F> {
    /// The program's blocks, holding instructions in their analyzed form.
    pub inner: raw::RawProgram<AnalyzedInstruction<F>>,
    /// Memory size of the program. The validating constructors either set this to the highest
    /// written address plus one or require it to be at least that large.
    pub total_memory: usize,
    /// Optional shape of the program.
    /// NOTE(review): semantics are defined by `RecursionShape`, which is not visible here.
    pub shape: Option<RecursionShape<F>>,
    /// Event counts for the program; `linear_program` fills this from `RawProgram::analyze`.
    pub event_counts: RecursionAirEventCount,
}
212
213impl<F> Default for RootProgram<F> {
215 fn default() -> Self {
216 Self {
217 inner: Default::default(),
218 total_memory: Default::default(),
219 shape: None,
220 event_counts: Default::default(),
221 }
222 }
223}
224
225pub mod raw {
226 use std::iter::Flatten;
227
228 use super::*;
229
230 #[derive(Debug, Clone, Serialize, Deserialize)]
231 pub struct RawProgram<T> {
232 pub seq_blocks: Vec<SeqBlock<T>>,
233 }
234
235 impl<T> Default for RawProgram<T> {
237 fn default() -> Self {
238 Self { seq_blocks: Default::default() }
239 }
240 }
241
242 impl<T> RawProgram<T> {
243 pub fn iter(&self) -> impl Iterator<Item = &'_ T> {
244 self.seq_blocks.iter().flatten()
245 }
246 pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
247 self.seq_blocks.iter_mut().flatten()
248 }
249 }
250
251 impl<T> IntoIterator for RawProgram<T> {
252 type Item = T;
253
254 type IntoIter = Flatten<<Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
255
256 fn into_iter(self) -> Self::IntoIter {
257 self.seq_blocks.into_iter().flatten()
258 }
259 }
260
261 impl<'a, T> IntoIterator for &'a RawProgram<T> {
262 type Item = &'a T;
263
264 type IntoIter = Flatten<<&'a Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
265
266 fn into_iter(self) -> Self::IntoIter {
267 self.seq_blocks.iter().flatten()
268 }
269 }
270
271 impl<'a, T> IntoIterator for &'a mut RawProgram<T> {
272 type Item = &'a mut T;
273
274 type IntoIter = Flatten<<&'a mut Vec<SeqBlock<T>> as IntoIterator>::IntoIter>;
275
276 fn into_iter(self) -> Self::IntoIter {
277 self.seq_blocks.iter_mut().flatten()
278 }
279 }
280}
281
282pub mod seq_block {
283 use std::iter::Flatten;
284
285 use super::*;
286
287 #[derive(Debug, Clone, Serialize, Deserialize)]
289 pub enum SeqBlock<T> {
290 Basic(BasicBlock<T>),
292 Parallel(Vec<RawProgram<T>>),
294 }
295
296 impl<T> SeqBlock<T> {
297 pub fn iter(&self) -> Iter<'_, T> {
298 self.into_iter()
299 }
300
301 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
302 self.into_iter()
303 }
304 }
305
306 #[derive(Debug)]
308 pub enum Iter<'a, T> {
309 Basic(<&'a Vec<T> as IntoIterator>::IntoIter),
310 Parallel(Box<Flatten<<&'a Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
311 }
312
313 impl<'a, T> Iterator for Iter<'a, T> {
314 type Item = &'a T;
315
316 fn next(&mut self) -> Option<Self::Item> {
317 match self {
318 Iter::Basic(it) => it.next(),
319 Iter::Parallel(it) => it.next(),
320 }
321 }
322 }
323
324 impl<'a, T> IntoIterator for &'a SeqBlock<T> {
325 type Item = &'a T;
326
327 type IntoIter = Iter<'a, T>;
328
329 fn into_iter(self) -> Self::IntoIter {
330 match self {
331 SeqBlock::Basic(basic_block) => Iter::Basic(basic_block.instrs.iter()),
332 SeqBlock::Parallel(vec) => Iter::Parallel(Box::new(vec.iter().flatten())),
333 }
334 }
335 }
336
337 #[derive(Debug)]
338 pub enum IterMut<'a, T> {
339 Basic(<&'a mut Vec<T> as IntoIterator>::IntoIter),
340 Parallel(Box<Flatten<<&'a mut Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
341 }
342
343 impl<'a, T> Iterator for IterMut<'a, T> {
344 type Item = &'a mut T;
345
346 fn next(&mut self) -> Option<Self::Item> {
347 match self {
348 IterMut::Basic(it) => it.next(),
349 IterMut::Parallel(it) => it.next(),
350 }
351 }
352 }
353
354 impl<'a, T> IntoIterator for &'a mut SeqBlock<T> {
355 type Item = &'a mut T;
356
357 type IntoIter = IterMut<'a, T>;
358
359 fn into_iter(self) -> Self::IntoIter {
360 match self {
361 SeqBlock::Basic(basic_block) => IterMut::Basic(basic_block.instrs.iter_mut()),
362 SeqBlock::Parallel(vec) => IterMut::Parallel(Box::new(vec.iter_mut().flatten())),
363 }
364 }
365 }
366
367 #[derive(Debug, Clone)]
368 pub enum IntoIter<T> {
369 Basic(<Vec<T> as IntoIterator>::IntoIter),
370 Parallel(Box<Flatten<<Vec<RawProgram<T>> as IntoIterator>::IntoIter>>),
371 }
372
373 impl<T> Iterator for IntoIter<T> {
374 type Item = T;
375
376 fn next(&mut self) -> Option<Self::Item> {
377 match self {
378 IntoIter::Basic(it) => it.next(),
379 IntoIter::Parallel(it) => it.next(),
380 }
381 }
382 }
383
384 impl<T> IntoIterator for SeqBlock<T> {
385 type Item = T;
386
387 type IntoIter = IntoIter<T>;
388
389 fn into_iter(self) -> Self::IntoIter {
390 match self {
391 SeqBlock::Basic(basic_block) => IntoIter::Basic(basic_block.instrs.into_iter()),
392 SeqBlock::Parallel(vec) => IntoIter::Parallel(Box::new(vec.into_iter().flatten())),
393 }
394 }
395 }
396}
397
398pub mod basic_block {
399 use super::*;
400
401 #[derive(Debug, Clone, Serialize, Deserialize)]
402 pub struct BasicBlock<T> {
403 pub instrs: Vec<T>,
404 }
405
406 impl<T> Default for BasicBlock<T> {
408 fn default() -> Self {
409 Self { instrs: Default::default() }
410 }
411 }
412}