sp1_recursion_core/
shape.rs

1#![allow(clippy::never_loop)]
2
3use std::marker::PhantomData;
4
5use hashbrown::HashMap;
6
7use itertools::Itertools;
8use p3_field::{extension::BinomiallyExtendable, PrimeField32};
9use serde::{Deserialize, Serialize};
10use sp1_stark::{air::MachineAir, shape::OrderedShape};
11
12use crate::{
13    chips::{
14        alu_base::BaseAluChip,
15        alu_ext::ExtAluChip,
16        batch_fri::BatchFRIChip,
17        exp_reverse_bits::ExpReverseBitsLenChip,
18        mem::{MemoryConstChip, MemoryVarChip},
19        poseidon2_wide::Poseidon2WideChip,
20        public_values::{PublicValuesChip, PUB_VALUES_LOG_HEIGHT},
21        select::SelectChip,
22    },
23    machine::RecursionAir,
24    RecursionProgram, D,
25};
26
27#[derive(Debug, Clone, Default, Serialize, Deserialize)]
28pub struct RecursionShape {
29    pub(crate) inner: HashMap<String, usize>,
30}
31
32impl RecursionShape {
33    pub fn clone_into_hash_map(&self) -> HashMap<String, usize> {
34        self.inner.clone()
35    }
36}
37
38impl From<HashMap<String, usize>> for RecursionShape {
39    fn from(value: HashMap<String, usize>) -> Self {
40        Self { inner: value }
41    }
42}
43
44pub struct RecursionShapeConfig<F, A> {
45    allowed_shapes: Vec<HashMap<String, usize>>,
46    _marker: PhantomData<(F, A)>,
47}
48
49impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize>
50    RecursionShapeConfig<F, RecursionAir<F, DEGREE>>
51{
52    pub fn fix_shape(&self, program: &mut RecursionProgram<F>) {
53        let heights = RecursionAir::<F, DEGREE>::heights(program);
54
55        let mut closest_shape = None;
56
57        for shape in self.allowed_shapes.iter() {
58            // If any of the heights is greater than the shape, continue.
59            let mut valid = true;
60            for (name, height) in heights.iter() {
61                if *height > (1 << shape.get(name).unwrap()) {
62                    valid = false;
63                }
64            }
65
66            if !valid {
67                continue;
68            }
69
70            closest_shape = Some(shape.clone());
71            break;
72        }
73
74        if let Some(shape) = closest_shape {
75            let shape = RecursionShape { inner: shape };
76            *program.shape_mut() = Some(shape);
77        } else {
78            panic!("no shape found for heights: {heights:?}");
79        }
80    }
81
82    pub fn get_all_shape_combinations(
83        &self,
84        batch_size: usize,
85    ) -> impl Iterator<Item = Vec<OrderedShape>> + '_ {
86        (0..batch_size)
87            .map(|_| {
88                self.allowed_shapes
89                    .iter()
90                    .cloned()
91                    .map(|map| map.into_iter().collect::<OrderedShape>())
92            })
93            .multi_cartesian_product()
94    }
95
96    pub fn union_config_with_extra_room(&self) -> Self {
97        let mut map = HashMap::new();
98        for shape in self.allowed_shapes.clone() {
99            for key in shape.keys() {
100                let current = map.get(key).unwrap_or(&0);
101                map.insert(key.clone(), *current.max(shape.get(key).unwrap()));
102            }
103        }
104        map.values_mut().for_each(|x| *x += 2);
105        map.insert("PublicValues".to_string(), 4);
106        Self { allowed_shapes: vec![map], _marker: PhantomData }
107    }
108
109    pub fn from_hash_map(hash_map: &HashMap<String, usize>) -> Self {
110        Self { allowed_shapes: vec![hash_map.clone()], _marker: PhantomData }
111    }
112
113    pub fn first(&self) -> Option<&HashMap<String, usize>> {
114        self.allowed_shapes.first()
115    }
116}
117
118impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> Default
119    for RecursionShapeConfig<F, RecursionAir<F, DEGREE>>
120{
121    fn default() -> Self {
122        // Get the names of all the recursion airs to make the shape specification more readable.
123        let mem_const = RecursionAir::<F, DEGREE>::MemoryConst(MemoryConstChip::default()).name();
124        let mem_var = RecursionAir::<F, DEGREE>::MemoryVar(MemoryVarChip::default()).name();
125        let base_alu = RecursionAir::<F, DEGREE>::BaseAlu(BaseAluChip).name();
126        let ext_alu = RecursionAir::<F, DEGREE>::ExtAlu(ExtAluChip).name();
127        let poseidon2_wide =
128            RecursionAir::<F, DEGREE>::Poseidon2Wide(Poseidon2WideChip::<DEGREE>).name();
129        let batch_fri = RecursionAir::<F, DEGREE>::BatchFRI(BatchFRIChip::<DEGREE>).name();
130        let select = RecursionAir::<F, DEGREE>::Select(SelectChip).name();
131        let exp_reverse_bits_len =
132            RecursionAir::<F, DEGREE>::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>).name();
133        let public_values = RecursionAir::<F, DEGREE>::PublicValues(PublicValuesChip).name();
134
135        // Specify allowed shapes.
136
137        let allowed_shapes = [
138            // Fastest shape.
139            [
140                (mem_var.clone(), 19),
141                (select.clone(), 19),
142                (mem_const.clone(), 17),
143                (batch_fri.clone(), 19),
144                (base_alu.clone(), 16),
145                (ext_alu.clone(), 16),
146                (exp_reverse_bits_len.clone(), 18),
147                (poseidon2_wide.clone(), 17),
148                (public_values.clone(), PUB_VALUES_LOG_HEIGHT),
149            ],
150            // Second fastest shape.
151            [
152                (mem_var.clone(), 20),
153                (select.clone(), 20),
154                (mem_const.clone(), 18),
155                (batch_fri.clone(), 21),
156                (base_alu.clone(), 16),
157                (ext_alu.clone(), 19),
158                (exp_reverse_bits_len.clone(), 18),
159                (poseidon2_wide.clone(), 18),
160                (public_values.clone(), PUB_VALUES_LOG_HEIGHT),
161            ],
162        ]
163        .map(HashMap::from)
164        .to_vec();
165        Self { allowed_shapes, _marker: PhantomData }
166    }
167}