sp1_core_machine/shape/
shapeable.rs1use hashbrown::HashMap;
2use itertools::Itertools;
3
4use sp1_core_executor::{events::PrecompileLocalMemory, ExecutionRecord, RiscvAirId};
5use sp1_stark::MachineRecord;
6
7use crate::memory::NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
8
/// Classification of a shard by the kinds of events it carries.
///
/// `Shapeable::kind` for `ExecutionRecord` picks the variant from two flags: whether
/// `contains_cpu()` holds, and whether the record has any global memory initialize or finalize
/// events.
///
/// The enum is field-less, so it also derives `PartialEq`, `Eq` and `Hash`; kinds can be compared
/// with `==` and used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShardKind {
    /// `contains_cpu()` holds and global memory initialize/finalize events are present.
    PackedCore,
    /// `contains_cpu()` holds and there are no global memory initialize/finalize events.
    Core,
    /// `contains_cpu()` does not hold, but global memory initialize/finalize events are present.
    GlobalMemory,
    /// Neither of the two conditions holds.
    Precompile,
}
16
/// Interface exposing a shard's kind, shard number, and per-AIR heights.
///
/// The notes on each method describe what the `ExecutionRecord` implementation in this file
/// does; other implementors may differ.
pub trait Shapeable {
    /// Returns the [`ShardKind`] of this value.
    fn kind(&self) -> ShardKind;
    /// Returns the shard number (`public_values.shard` for `ExecutionRecord`).
    fn shard(&self) -> u32;
    /// Returns the base-2 logarithm of the shard size.
    ///
    /// For `ExecutionRecord` this is the log2 of the CPU event count rounded up to the next
    /// power of two.
    fn log2_shard_size(&self) -> usize;
    /// Returns a map from a name to a count, intended for debugging output.
    ///
    /// For `ExecutionRecord` this forwards to `stats()`.
    fn debug_stats(&self) -> HashMap<String, usize>;
    /// Returns `(AIR id, height)` pairs for the core AIRs.
    ///
    /// For `ExecutionRecord` each height is the length of the matching event list, except
    /// `MemoryLocal`, which counts chunks of `NUM_LOCAL_MEMORY_ENTRIES_PER_ROW` local memory
    /// events.
    fn core_heights(&self) -> Vec<(RiscvAirId, usize)>;
    /// Returns `(AIR id, height)` pairs for the global memory AIRs.
    ///
    /// For `ExecutionRecord` these are the initialize count, the finalize count, and their sum
    /// (reported under `RiscvAirId::Global`).
    fn memory_heights(&self) -> Vec<(RiscvAirId, usize)>;
    /// Yields `(AIR id, (a, b, c))` for precompiles.
    ///
    /// For `ExecutionRecord` the triple is, in order: the number of precompile events times
    /// `rows_per_event()` of the AIR, the number of local memory events of those precompile
    /// events, and the total number of global interaction events in the record.
    fn precompile_heights(&self) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))>;
}
28
29impl Shapeable for ExecutionRecord {
30 fn kind(&self) -> ShardKind {
31 let contains_global_memory = !self.global_memory_initialize_events.is_empty() ||
32 !self.global_memory_finalize_events.is_empty();
33 match (self.contains_cpu(), contains_global_memory) {
34 (true, true) => ShardKind::PackedCore,
35 (true, false) => ShardKind::Core,
36 (false, true) => ShardKind::GlobalMemory,
37 (false, false) => ShardKind::Precompile,
38 }
39 }
40 fn shard(&self) -> u32 {
41 self.public_values.shard
42 }
43
44 fn log2_shard_size(&self) -> usize {
45 self.cpu_events.len().next_power_of_two().ilog2() as usize
46 }
47
48 fn debug_stats(&self) -> HashMap<String, usize> {
49 self.stats()
50 }
51
52 fn core_heights(&self) -> Vec<(RiscvAirId, usize)> {
53 vec![
54 (RiscvAirId::Cpu, self.cpu_events.len()),
55 (RiscvAirId::DivRem, self.divrem_events.len()),
56 (RiscvAirId::AddSub, self.add_events.len() + self.sub_events.len()),
57 (RiscvAirId::Bitwise, self.bitwise_events.len()),
58 (RiscvAirId::Mul, self.mul_events.len()),
59 (RiscvAirId::ShiftRight, self.shift_right_events.len()),
60 (RiscvAirId::ShiftLeft, self.shift_left_events.len()),
61 (RiscvAirId::Lt, self.lt_events.len()),
62 (
63 RiscvAirId::MemoryLocal,
64 self.get_local_mem_events()
65 .chunks(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
66 .into_iter()
67 .count(),
68 ),
69 (RiscvAirId::MemoryInstrs, self.memory_instr_events.len()),
70 (RiscvAirId::Auipc, self.auipc_events.len()),
71 (RiscvAirId::Branch, self.branch_events.len()),
72 (RiscvAirId::Jump, self.jump_events.len()),
73 (RiscvAirId::Global, self.global_interaction_events.len()),
74 (RiscvAirId::SyscallCore, self.syscall_events.len()),
75 (RiscvAirId::SyscallInstrs, self.syscall_events.len()),
76 ]
77 }
78
79 fn memory_heights(&self) -> Vec<(RiscvAirId, usize)> {
80 vec![
81 (RiscvAirId::MemoryGlobalInit, self.global_memory_initialize_events.len()),
82 (RiscvAirId::MemoryGlobalFinalize, self.global_memory_finalize_events.len()),
83 (
84 RiscvAirId::Global,
85 self.global_memory_finalize_events.len() +
86 self.global_memory_initialize_events.len(),
87 ),
88 ]
89 }
90
91 fn precompile_heights(&self) -> impl Iterator<Item = (RiscvAirId, (usize, usize, usize))> {
92 self.precompile_events.events.iter().filter_map(|(code, events)| {
93 (!events.is_empty()).then_some(())?;
95 let id = code.as_air_id()?;
96 Some((
97 id,
98 (
99 events.len() * id.rows_per_event(),
100 events.get_local_mem_events().into_iter().count(),
101 self.global_interaction_events.len(),
102 ),
103 ))
104 })
105 }
106}