sp1_core_machine/shape/shapeable.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3
4use sp1_core_executor::{events::PrecompileLocalMemory, ExecutionRecord, RiscvAirId};
5use sp1_stark::MachineRecord;
6
7use crate::memory::NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
8
/// Coarse classification of a shard by the kinds of events it contains.
///
/// For `ExecutionRecord` the classification is derived from two facts: whether the record
/// reports CPU activity (`contains_cpu()`), and whether it holds any global memory
/// initialize/finalize events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShardKind {
    /// CPU activity together with global memory initialize/finalize events.
    PackedCore,
    /// CPU activity without any global memory initialize/finalize events.
    Core,
    /// Global memory initialize/finalize events without CPU activity.
    GlobalMemory,
    /// Neither CPU activity nor global memory events (presumably a precompile-only shard,
    /// per the name).
    Precompile,
}
16
17pub trait Shapeable {
18    fn kind(&self) -> ShardKind;
19    fn shard(&self) -> u32;
20    fn log2_shard_size(&self) -> usize;
21    fn debug_stats(&self) -> HashMap<String, usize>;
22    fn core_heights(&self) -> Vec<(RiscvAirId, usize)>;
23    fn memory_heights(&self) -> Vec<(RiscvAirId, usize)>;
24    /// TODO. Returns all precompile events, assuming there is only one kind in `Self`.
25    /// The tuple is of the form `(height, (num_memory_local_events, num_global_events))`
26    fn precompile_heights(&self) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))>;
27}
28
29impl Shapeable for ExecutionRecord {
30    fn kind(&self) -> ShardKind {
31        let contains_global_memory = !self.global_memory_initialize_events.is_empty() ||
32            !self.global_memory_finalize_events.is_empty();
33        match (self.contains_cpu(), contains_global_memory) {
34            (true, true) => ShardKind::PackedCore,
35            (true, false) => ShardKind::Core,
36            (false, true) => ShardKind::GlobalMemory,
37            (false, false) => ShardKind::Precompile,
38        }
39    }
40    fn shard(&self) -> u32 {
41        self.public_values.shard
42    }
43
44    fn log2_shard_size(&self) -> usize {
45        self.cpu_events.len().next_power_of_two().ilog2() as usize
46    }
47
48    fn debug_stats(&self) -> HashMap<String, usize> {
49        self.stats()
50    }
51
52    fn core_heights(&self) -> Vec<(RiscvAirId, usize)> {
53        vec![
54            (RiscvAirId::Cpu, self.cpu_events.len()),
55            (RiscvAirId::DivRem, self.divrem_events.len()),
56            (RiscvAirId::AddSub, self.add_events.len() + self.sub_events.len()),
57            (RiscvAirId::Bitwise, self.bitwise_events.len()),
58            (RiscvAirId::Mul, self.mul_events.len()),
59            (RiscvAirId::ShiftRight, self.shift_right_events.len()),
60            (RiscvAirId::ShiftLeft, self.shift_left_events.len()),
61            (RiscvAirId::Lt, self.lt_events.len()),
62            (
63                RiscvAirId::MemoryLocal,
64                self.get_local_mem_events()
65                    .chunks(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
66                    .into_iter()
67                    .count(),
68            ),
69            (RiscvAirId::MemoryInstrs, self.memory_instr_events.len()),
70            (RiscvAirId::Auipc, self.auipc_events.len()),
71            (RiscvAirId::Branch, self.branch_events.len()),
72            (RiscvAirId::Jump, self.jump_events.len()),
73            (RiscvAirId::Global, self.global_interaction_events.len()),
74            (RiscvAirId::SyscallCore, self.syscall_events.len()),
75            (RiscvAirId::SyscallInstrs, self.syscall_events.len()),
76        ]
77    }
78
79    fn memory_heights(&self) -> Vec<(RiscvAirId, usize)> {
80        vec![
81            (RiscvAirId::MemoryGlobalInit, self.global_memory_initialize_events.len()),
82            (RiscvAirId::MemoryGlobalFinalize, self.global_memory_finalize_events.len()),
83            (
84                RiscvAirId::Global,
85                self.global_memory_finalize_events.len() +
86                    self.global_memory_initialize_events.len(),
87            ),
88        ]
89    }
90
91    fn precompile_heights(&self) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))> {
92        self.precompile_events.events.iter().filter_map(|(code, events)| {
93            // Skip empty events.
94            (!events.is_empty()).then_some(())?;
95            let id = code.as_air_id()?;
96            Some((
97                id,
98                (
99                    events.len() * id.rows_per_event(),
100                    events.get_local_mem_events().into_iter().count(),
101                    self.global_interaction_events.len(),
102                ),
103            ))
104        })
105    }
106}