Skip to main content

sp1_hypercube/
machine.rs

1use derive_where::derive_where;
2use slop_algebra::Field;
3use std::collections::BTreeSet;
4
5use crate::{air::MachineAir, Chip, MachineRecord};
6
/// A shape for a machine.
///
/// A shape is a list of chip clusters, where each cluster is an ordered set of chips.
/// [`MachineShape::smallest_cluster`] selects, among these clusters, the smallest one that
/// covers a requested set of chips.
// `Debug` is only derived when `A: MachineAir<F>`; `Clone` is derived without extra bounds.
#[derive_where(Debug; A: MachineAir<F>)]
#[derive_where(Clone)]
pub struct MachineShape<F: Field, A> {
    /// The chip clusters, kept in the order they were provided.
    pub chip_clusters: Vec<BTreeSet<Chip<F, A>>>,
}
14
15impl<F: Field, A: MachineAir<F>> MachineShape<F, A> {
16    /// Create a single shape that always includes all the chips.
17    #[must_use]
18    pub fn all(chips: &[Chip<F, A>]) -> Self {
19        let chip_clusters = vec![chips.iter().cloned().collect()];
20        Self { chip_clusters }
21    }
22
23    /// Create a new shape from a list of chip clusters.
24    #[must_use]
25    pub const fn new(chip_clusters: Vec<BTreeSet<Chip<F, A>>>) -> Self {
26        Self { chip_clusters }
27    }
28
29    /// Returns the smallest shape cluster that contains all the chips with given names.
30    #[must_use]
31    pub fn smallest_cluster(&self, chips: &BTreeSet<Chip<F, A>>) -> Option<&BTreeSet<Chip<F, A>>> {
32        self.chip_clusters
33            .iter()
34            .filter(|cluster| chips.is_subset(cluster))
35            .min_by_key(|cluster| cluster.len())
36    }
37}
38
/// A STARK for proving RISC-V execution.
///
/// Bundles the machine's chips, the number of public-value elements it uses, and its
/// [`MachineShape`] (the chip clusters it may be proven with).
// `Debug` is only derived when `A: MachineAir<F>`; `Clone` is derived without extra bounds.
#[derive_where(Debug; A: MachineAir<F>)]
#[derive_where(Clone)]
pub struct Machine<F: Field, A> {
    /// The chips that make up the machine. Methods such as `generate_dependencies` visit
    /// them in this order.
    chips: Vec<Chip<F, A>>,
    /// The number of public values elements that the machine uses.
    num_pv_elts: usize,
    /// The shape of the machine: the chip clusters available to it.
    shape: MachineShape<F, A>,
}
50
51impl<F, A> Machine<F, A>
52where
53    F: Field,
54    A: MachineAir<F>,
55{
56    /// Creates a new [`StarkMachine`].
57    #[must_use]
58    pub const fn new(
59        chips: Vec<Chip<F, A>>,
60        num_pv_elts: usize,
61        shape: MachineShape<F, A>,
62    ) -> Self {
63        Self { chips, num_pv_elts, shape }
64    }
65
66    /// Returns the chips in the machine.
67    #[must_use]
68    pub fn chips(&self) -> &[Chip<F, A>] {
69        &self.chips
70    }
71
72    /// Returns the number of public values elements.
73    #[must_use]
74    pub const fn num_pv_elts(&self) -> usize {
75        self.num_pv_elts
76    }
77
78    /// Returns the shape of the machine.
79    #[must_use]
80    pub const fn shape(&self) -> &MachineShape<F, A> {
81        &self.shape
82    }
83
84    /// Returns the smallest shape cluster that contains all the chips with given names.
85    #[must_use]
86    pub fn smallest_cluster(&self, chips: &BTreeSet<Chip<F, A>>) -> Option<&BTreeSet<Chip<F, A>>> {
87        self.shape.smallest_cluster(chips)
88    }
89
90    /// Generates the dependencies of the given records.
91    #[allow(clippy::needless_for_each)]
92    pub fn generate_dependencies<'a>(
93        &self,
94        records: impl Iterator<Item = &'a mut A::Record>,
95        chips_filter: Option<&[String]>,
96    ) {
97        let chips = self
98            .chips
99            .iter()
100            .filter(|chip| {
101                if let Some(chips_filter) = chips_filter {
102                    chips_filter.contains(&chip.name().to_string())
103                } else {
104                    true
105                }
106            })
107            .collect::<Vec<_>>();
108
109        records.for_each(|record| {
110            chips.iter().for_each(|chip| {
111                let mut output = A::Record::default();
112                chip.generate_dependencies(record, &mut output);
113                record.append(&mut output);
114            });
115        });
116    }
117}