1use std::env;
2
3use serde::{Deserialize, Serialize};
4use sysinfo::System;
5
/// Default core shard size (2^21 cycles) used by [`SP1CoreOpts::max`].
const MAX_SHARD_SIZE: usize = 1 << 21;
/// Shard size (2^22) used for recursion proving in [`SP1CoreOpts::recursion`].
const RECURSION_MAX_SHARD_SIZE: usize = 1 << 22;
/// Default shard batch size used by [`SP1CoreOpts::max`].
const MAX_SHARD_BATCH_SIZE: usize = 8;
/// Default number of trace-generation workers.
const DEFAULT_TRACE_GEN_WORKERS: usize = 1;
/// Default capacity of the checkpoints channel.
const DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY: usize = 128;
/// Default capacity of the records-and-traces channel.
const DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY: usize = 1;
/// Maximum deferred-split threshold (2^15); also the floor applied to the
/// `SPLIT_THRESHOLD` env override in [`SP1CoreOpts::max`].
const MAX_DEFERRED_SPLIT_THRESHOLD: usize = 1 << 15;
13
/// Options for the SP1 prover, split between the core and recursion phases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SP1ProverOpts {
    /// Options used while proving the core (RISC-V execution) phase.
    pub core_opts: SP1CoreOpts,
    /// Options used while proving the recursion phase.
    pub recursion_opts: SP1CoreOpts,
}
22
23impl SP1ProverOpts {
24 #[must_use]
26 pub fn auto() -> Self {
27 let cpu_ram_gb = System::new_all().total_memory() / (1024 * 1024 * 1024);
28 SP1ProverOpts::cpu(cpu_ram_gb as usize)
29 }
30
31 #[must_use]
34 fn get_memory_opts(cpu_ram_gb: usize) -> (usize, usize, usize) {
35 match cpu_ram_gb {
36 0..33 => (19, 1, 3),
37 33..49 => (20, 1, 2),
38 49..65 => (21, 1, 3),
39 65..81 => (21, 3, 1),
40 81.. => (21, 4, 1),
41 }
42 }
43
44 #[must_use]
48 pub fn cpu(cpu_ram_gb: usize) -> Self {
49 let (log2_shard_size, shard_batch_size, log2_divisor) = Self::get_memory_opts(cpu_ram_gb);
50
51 let mut opts = SP1ProverOpts::default();
52 opts.core_opts.shard_size = 1 << log2_shard_size;
53 opts.core_opts.shard_batch_size = shard_batch_size;
54
55 opts.core_opts.records_and_traces_channel_capacity = 1;
56 opts.core_opts.trace_gen_workers = 1;
57
58 let divisor = 1 << log2_divisor;
59 opts.core_opts.split_opts.deferred /= divisor;
60 opts.core_opts.split_opts.keccak /= divisor;
61 opts.core_opts.split_opts.sha_extend /= divisor;
62 opts.core_opts.split_opts.sha_compress /= divisor;
63 opts.core_opts.split_opts.memory /= divisor;
64
65 opts.recursion_opts.shard_batch_size = 2;
66 opts.recursion_opts.records_and_traces_channel_capacity = 1;
67 opts.recursion_opts.trace_gen_workers = 1;
68
69 opts
70 }
71
72 #[must_use]
74 pub fn gpu(cpu_ram_gb: usize, gpu_ram_gb: usize) -> Self {
75 let mut opts = SP1ProverOpts::default();
76
77 if 24 <= gpu_ram_gb {
79 let log2_shard_size = 21;
80 opts.core_opts.shard_size = 1 << log2_shard_size;
81 opts.core_opts.shard_batch_size = 1;
82
83 let log2_deferred_threshold = 14;
84 opts.core_opts.split_opts = SplitOpts::new(1 << log2_deferred_threshold);
85
86 opts.core_opts.records_and_traces_channel_capacity = 4;
87 opts.core_opts.trace_gen_workers = 4;
88
89 if cpu_ram_gb <= 20 {
90 opts.core_opts.records_and_traces_channel_capacity = 1;
91 opts.core_opts.trace_gen_workers = 2;
92 }
93 } else {
94 unreachable!("not enough gpu memory");
95 }
96
97 opts.recursion_opts.shard_batch_size = 1;
99
100 opts
101 }
102}
103
/// Options for core proving.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SP1CoreOpts {
    /// Number of cycles per shard.
    pub shard_size: usize,
    /// Number of shards proved per batch.
    pub shard_batch_size: usize,
    /// Thresholds controlling how deferred records are split.
    pub split_opts: SplitOpts,
    /// Number of worker threads for trace generation.
    pub trace_gen_workers: usize,
    /// Capacity of the checkpoints channel.
    pub checkpoints_channel_capacity: usize,
    /// Capacity of the records-and-traces channel.
    pub records_and_traces_channel_capacity: usize,
}
120
121impl Default for SP1ProverOpts {
122 fn default() -> Self {
123 Self { core_opts: SP1CoreOpts::default(), recursion_opts: SP1CoreOpts::recursion() }
124 }
125}
126
127impl Default for SP1CoreOpts {
128 fn default() -> Self {
129 let cpu_ram_gb = System::new_all().total_memory() / (1024 * 1024 * 1024);
130 let (default_log2_shard_size, default_shard_batch_size, default_log2_divisor) =
131 SP1ProverOpts::get_memory_opts(cpu_ram_gb as usize);
132
133 let mut opts = Self {
134 shard_size: env::var("SHARD_SIZE").map_or_else(
135 |_| 1 << default_log2_shard_size,
136 |s| s.parse::<usize>().unwrap_or(1 << default_log2_shard_size),
137 ),
138 shard_batch_size: env::var("SHARD_BATCH_SIZE").map_or_else(
139 |_| default_shard_batch_size,
140 |s| s.parse::<usize>().unwrap_or(default_shard_batch_size),
141 ),
142 split_opts: SplitOpts::new(MAX_DEFERRED_SPLIT_THRESHOLD),
143 trace_gen_workers: env::var("TRACE_GEN_WORKERS").map_or_else(
144 |_| DEFAULT_TRACE_GEN_WORKERS,
145 |s| s.parse::<usize>().unwrap_or(DEFAULT_TRACE_GEN_WORKERS),
146 ),
147 checkpoints_channel_capacity: env::var("CHECKPOINTS_CHANNEL_CAPACITY").map_or_else(
148 |_| DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY,
149 |s| s.parse::<usize>().unwrap_or(DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY),
150 ),
151 records_and_traces_channel_capacity: env::var("RECORDS_AND_TRACES_CHANNEL_CAPACITY")
152 .map_or_else(
153 |_| DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY,
154 |s| s.parse::<usize>().unwrap_or(DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY),
155 ),
156 };
157
158 let divisor = 1 << default_log2_divisor;
159 opts.split_opts.deferred /= divisor;
160 opts.split_opts.keccak /= divisor;
161 opts.split_opts.sha_extend /= divisor;
162 opts.split_opts.sha_compress /= divisor;
163 opts.split_opts.memory /= divisor;
164
165 opts
166 }
167}
168
169impl SP1CoreOpts {
170 #[must_use]
172 pub fn recursion() -> Self {
173 let mut opts = Self::max();
174 opts.shard_size = RECURSION_MAX_SHARD_SIZE;
175 opts.shard_batch_size = 2;
176 opts
177 }
178
179 #[must_use]
181 pub fn max() -> Self {
182 let split_threshold = env::var("SPLIT_THRESHOLD")
183 .map(|s| s.parse::<usize>().unwrap_or(MAX_DEFERRED_SPLIT_THRESHOLD))
184 .unwrap_or(MAX_DEFERRED_SPLIT_THRESHOLD)
185 .max(MAX_DEFERRED_SPLIT_THRESHOLD);
186
187 let shard_size = env::var("SHARD_SIZE")
188 .map_or_else(|_| MAX_SHARD_SIZE, |s| s.parse::<usize>().unwrap_or(MAX_SHARD_SIZE));
189
190 Self {
191 shard_size,
192 shard_batch_size: env::var("SHARD_BATCH_SIZE").map_or_else(
193 |_| MAX_SHARD_BATCH_SIZE,
194 |s| s.parse::<usize>().unwrap_or(MAX_SHARD_BATCH_SIZE),
195 ),
196 split_opts: SplitOpts::new(split_threshold),
197 trace_gen_workers: env::var("TRACE_GEN_WORKERS").map_or_else(
198 |_| DEFAULT_TRACE_GEN_WORKERS,
199 |s| s.parse::<usize>().unwrap_or(DEFAULT_TRACE_GEN_WORKERS),
200 ),
201 checkpoints_channel_capacity: env::var("CHECKPOINTS_CHANNEL_CAPACITY").map_or_else(
202 |_| DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY,
203 |s| s.parse::<usize>().unwrap_or(DEFAULT_CHECKPOINTS_CHANNEL_CAPACITY),
204 ),
205 records_and_traces_channel_capacity: env::var("RECORDS_AND_TRACES_CHANNEL_CAPACITY")
206 .map_or_else(
207 |_| DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY,
208 |s| s.parse::<usize>().unwrap_or(DEFAULT_RECORDS_AND_TRACES_CHANNEL_CAPACITY),
209 ),
210 }
211 }
212}
213
/// Thresholds that control how deferred records are split into shards.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitOpts {
    /// Threshold for combining memory records.
    /// NOTE(review): fixed at `1 << 17` in `SplitOpts::new`, unlike the other
    /// fields which scale with the deferred threshold — confirm intent.
    pub combine_memory_threshold: usize,
    /// Split threshold for deferred records.
    pub deferred: usize,
    /// Split threshold for Keccak records.
    pub keccak: usize,
    /// Split threshold for SHA-extend records.
    pub sha_extend: usize,
    /// Split threshold for SHA-compress records.
    pub sha_compress: usize,
    /// Split threshold for memory records.
    pub memory: usize,
}
231
232impl SplitOpts {
233 #[must_use]
238 pub fn new(deferred_split_threshold: usize) -> Self {
239 Self {
240 combine_memory_threshold: 1 << 17,
241 deferred: deferred_split_threshold,
242 keccak: 8 * deferred_split_threshold / 24,
243 sha_extend: 32 * deferred_split_threshold / 48,
244 sha_compress: 32 * deferred_split_threshold / 80,
245 memory: 64 * deferred_split_threshold,
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// Checks that `SP1ProverOpts::cpu` applies the expected shard size and
    /// shard batch size for each RAM tier of `get_memory_opts`.
    ///
    /// Fix: the original test only printed the options and asserted nothing,
    /// so it could never fail. Only `shard_size` and `shard_batch_size` are
    /// asserted because `cpu` sets them unconditionally; `split_opts` depends
    /// on the host machine's RAM via `SP1CoreOpts::default`.
    #[test]
    fn test_opts() {
        // (ram_gb, expected log2 shard size, expected shard batch size)
        let cases = [
            (8, 19, 1),
            (15, 19, 1),
            (16, 19, 1),
            (32, 19, 1),
            (36, 20, 1),
            (64, 21, 1),
            (128, 21, 4),
            (256, 21, 4),
            (512, 21, 4),
        ];
        for (ram_gb, log2_shard_size, shard_batch_size) in cases {
            let opts = SP1ProverOpts::cpu(ram_gb);
            println!("{ram_gb}: {:?}", opts.core_opts);
            assert_eq!(opts.core_opts.shard_size, 1 << log2_shard_size);
            assert_eq!(opts.core_opts.shard_batch_size, shard_batch_size);
        }

        // `auto` depends on the host's detected RAM; just check it runs.
        let opts = SP1ProverOpts::auto();
        println!("auto: {:?}", opts.core_opts);
    }
}