1use anyhow::Result;
2use clap::{command, Parser};
3use reqwest::Client;
4use serde::Serialize;
5use serde_json::json;
6use slack_rust::{
7 chat::post_message::{post_message, PostMessageRequest},
8 http_client::default_client,
9};
10use sp1_prover::{components::SP1ProverComponents, utils::get_cycles, SP1Prover};
11use sp1_sdk::{SP1Context, SP1Stdin};
12use sp1_stark::SP1ProverOpts;
13use std::time::{Duration, Instant};
14
15use program::load_program;
16
17use crate::program::{TesterProgram, PROGRAMS};
18
19mod program;
20
// Command-line arguments for the performance evaluator, parsed with clap.
//
// Plain `//` comments are used deliberately: clap turns `///` doc comments into
// `--help` text, and these notes must not change the CLI's output.
#[derive(Parser, Clone)]
#[command(about = "Evaluate the performance of SP1 on programs.")]
struct EvalArgs {
    // Programs to evaluate, as a comma-separated list (e.g. `--programs a,b`).
    // Names are matched case-insensitively; an empty list means "all programs".
    #[arg(long, use_value_delimiter = true, value_delimiter = ',')]
    pub programs: Vec<String>,

    // Shard size given as a power-of-two exponent; `SHARD_SIZE` is set to 2^shard_size.
    #[arg(long)]
    pub shard_size: Option<usize>,

    // Post the results to Slack. Passing the bare flag (no value) means `true`.
    #[arg(long, default_missing_value="true", num_args=0..=1)]
    pub post_to_slack: Option<bool>,

    // Slack channel to post to; needed for Slack posting (if missing, only a warning is printed).
    #[arg(long)]
    pub slack_channel_id: Option<String>,

    // Slack API token; needed for Slack posting (if missing, only a warning is printed).
    #[arg(long)]
    pub slack_token: Option<String>,

    // Post the results as a GitHub PR comment. Passing the bare flag (no value) means `true`.
    #[arg(long, default_missing_value="true", num_args=0..=1)]
    pub post_to_github: Option<bool>,

    // GitHub token; needed for GitHub posting.
    #[arg(long)]
    pub github_token: Option<String>,

    // Owner of the repository that holds the PR; needed for GitHub posting.
    #[arg(long)]
    pub repo_owner: Option<String>,

    // Name of that repository; needed for GitHub posting.
    #[arg(long)]
    pub repo_name: Option<String>,

    // Pull request number to comment on; needed for GitHub posting.
    #[arg(long)]
    pub pr_number: Option<String>,

    // Branch name shown in the results header, when given.
    #[arg(long)]
    pub branch_name: Option<String>,

    // Commit hash, shown abbreviated in the results header when given.
    #[arg(long)]
    pub commit_hash: Option<String>,

    // Author shown in the results header, when given.
    #[arg(long)]
    pub author: Option<String>,
}
77
78pub async fn evaluate_performance<C: SP1ProverComponents>(
79 opts: SP1ProverOpts,
80) -> Result<(), Box<dyn std::error::Error>> {
81 println!("opts: {:?}", opts);
82
83 let args = EvalArgs::parse();
84
85 if let Some(shard_size) = args.shard_size {
87 std::env::set_var("SHARD_SIZE", format!("{}", 1 << shard_size));
88 }
89
90 let programs: Vec<&TesterProgram> = if args.programs.is_empty() {
92 PROGRAMS.iter().collect()
93 } else {
94 PROGRAMS
95 .iter()
96 .filter(|p| args.programs.iter().any(|arg| arg.eq_ignore_ascii_case(p.name)))
97 .collect()
98 };
99
100 sp1_sdk::utils::setup_logger();
101
102 let mut reports = Vec::new();
104 for program in &programs {
105 println!("Evaluating program: {}", program.name);
106 let (elf, stdin) = load_program(program.elf, program.input);
107 let report = run_evaluation::<C>(program.name, &elf, &stdin, opts);
108 reports.push(report);
109 println!("Finished Program: {}", program.name);
110 }
111
112 let reports_len = reports.len();
114 let success_count = reports.iter().filter(|r| r.success).count();
115 let results_text = format_results(&args, &reports);
116
117 println!("{}", results_text.join("\n"));
119
120 if args.post_to_slack.unwrap_or(false) {
122 match (&args.slack_token, &args.slack_channel_id) {
123 (Some(token), Some(channel)) => {
124 for message in &results_text {
125 post_to_slack(token, channel, message).await?;
126 }
127 }
128 _ => println!("Warning: post_to_slack is true, required Slack arguments are missing."),
129 }
130 }
131
132 if args.post_to_github.unwrap_or(false) {
134 match (&args.repo_owner, &args.repo_name, &args.pr_number, &args.github_token) {
135 (Some(owner), Some(repo), Some(pr_number), Some(token)) => {
136 let message = format_github_message(&results_text);
137 post_to_github_pr(owner, repo, pr_number, token, &message).await?;
138 }
139 _ => {
140 println!("Warning: post_to_github is true, required GitHub arguments are missing.")
141 }
142 }
143 }
144
145 let all_successful = success_count == reports_len;
147 if !all_successful {
148 println!("Some programs failed. Please check the results above.");
149 std::process::exit(1);
150 }
151
152 Ok(())
153}
154
/// Measured results for a single evaluated program.
///
/// Throughput fields are in kHz (thousands of cycles per second), as computed by
/// `calculate_khz` from `cycles` and the measured stage durations.
#[derive(Debug, Serialize)]
pub struct PerformanceReport {
    /// Name of the evaluated program.
    program: String,
    /// Cycle count obtained from `get_cycles` for this program and its input.
    cycles: u64,
    /// Throughput of plain execution, in kHz (displayed as MHz in the results table).
    exec_khz: f64,
    /// Throughput of core proving, in kHz.
    core_khz: f64,
    /// Throughput over core proving plus compression combined, in kHz.
    compressed_khz: f64,
    /// Wall-clock time spent in the timed stages, in seconds.
    time: f64,
    /// Whether the program's evaluation succeeded; rendered as ✅/❌, and any `false`
    /// makes `evaluate_performance` exit with status 1.
    success: bool,
}
165
166fn run_evaluation<C: SP1ProverComponents>(
167 program_name: &str,
168 elf: &[u8],
169 stdin: &SP1Stdin,
170 opts: SP1ProverOpts,
171) -> PerformanceReport {
172 let cycles = get_cycles(elf, stdin);
173
174 let prover = SP1Prover::<C>::new();
175 let (_, pk_d, program, vk) = prover.setup(elf);
176
177 let context = SP1Context::default();
178
179 let (_, exec_duration) = time_operation(|| prover.execute(elf, stdin, context.clone()));
180
181 let (core_proof, core_duration) =
182 time_operation(|| prover.prove_core(&pk_d, program, stdin, opts, context).unwrap());
183
184 let (_, compress_duration) =
185 time_operation(|| prover.compress(&vk, core_proof, vec![], opts).unwrap());
186
187 let total_duration = exec_duration + core_duration + compress_duration;
188
189 PerformanceReport {
190 program: program_name.to_string(),
191 cycles,
192 exec_khz: calculate_khz(cycles, exec_duration),
193 core_khz: calculate_khz(cycles, core_duration),
194 compressed_khz: calculate_khz(cycles, compress_duration + core_duration),
195 time: total_duration.as_secs_f64(),
196 success: true,
197 }
198}
199
200fn format_results(args: &EvalArgs, results: &[PerformanceReport]) -> Vec<String> {
201 let mut detail_text = String::new();
202 if let Some(branch_name) = &args.branch_name {
203 detail_text.push_str(&format!("*Branch*: {}\n", branch_name));
204 }
205 if let Some(commit_hash) = &args.commit_hash {
206 detail_text.push_str(&format!("*Commit*: {}\n", &commit_hash[..8]));
207 }
208 if let Some(author) = &args.author {
209 detail_text.push_str(&format!("*Author*: {}\n", author));
210 }
211
212 let mut table_text = String::new();
213 table_text.push_str("```\n");
214 table_text.push_str("| program | cycles | execute (mHz) | core (kHZ) | compress (KHz) | time | success |\n");
215 table_text.push_str("|-------------------|-------------|----------------|----------------|----------------|--------|----------|");
216
217 for result in results.iter() {
218 table_text.push_str(&format!(
219 "\n| {:<17} | {:>11} | {:>14.2} | {:>14.2} | {:>14.2} | {:>6} | {:<7} |",
220 result.program,
221 result.cycles,
222 result.exec_khz / 1000.0,
223 result.core_khz,
224 result.compressed_khz,
225 format_duration(result.time),
226 if result.success { "✅" } else { "❌" }
227 ));
228 }
229 table_text.push_str("\n```");
230
231 vec!["*SP1 Performance Test Results*\n".to_string(), detail_text, table_text]
232}
233
/// Runs `operation` exactly once and returns its output paired with the wall-clock time
/// the call took.
pub fn time_operation<T, F: FnOnce() -> T>(operation: F) -> (T, Duration) {
    let started_at = Instant::now();
    let output = operation();
    (output, started_at.elapsed())
}
240
/// Converts a cycle count and the time taken into throughput in kHz (thousands of cycles
/// per second). A zero duration yields `0.0` instead of dividing by zero.
fn calculate_khz(cycles: u64, duration: Duration) -> f64 {
    let secs = duration.as_secs_f64();
    if secs <= 0.0 {
        return 0.0;
    }
    cycles as f64 / secs / 1_000.0
}
249
/// Formats a duration in seconds for the results table.
///
/// The value is rounded to whole seconds. Results of a minute or more render as `XmYs`,
/// otherwise as `Ys`. When rounding gives zero seconds, the original value is shown in
/// whole milliseconds (`Zms`).
fn format_duration(duration: f64) -> String {
    let whole_secs = duration.round() as u64;
    match (whole_secs / 60, whole_secs % 60) {
        (0, 0) => format!("{}ms", (duration * 1000.0).round() as u64),
        (0, secs) => format!("{}s", secs),
        (mins, secs) => format!("{}m{}s", mins, secs),
    }
}
263
264async fn post_to_slack(slack_token: &str, slack_channel_id: &str, message: &str) -> Result<()> {
265 let slack_api_client = default_client();
266 let request = PostMessageRequest {
267 channel: slack_channel_id.to_string(),
268 text: Some(message.to_string()),
269 ..Default::default()
270 };
271
272 post_message(&slack_api_client, &request, slack_token).await.expect("slack api call error");
273
274 Ok(())
275}
276
/// Converts the Slack-formatted sections produced by `format_results` into GitHub Markdown.
///
/// The first two sections (title, run details) have Slack's `*bold*` rewritten to
/// Markdown's `**bold**` and are each followed by an extra newline. The third section
/// (the table) has its surrounding ``` fences stripped so GitHub can render it as a
/// Markdown table. Missing sections are skipped, and anything past the third is ignored.
fn format_github_message(results_text: &[String]) -> String {
    let mut markdown = String::new();

    for section in results_text.iter().take(2) {
        markdown += &section.replace('*', "**");
        markdown.push('\n');
    }

    if let Some(table) = results_text.get(2) {
        let without_fences = table.trim_start_matches("```").trim_end_matches("```");
        markdown += without_fences;
    }

    markdown
}
300
301async fn post_to_github_pr(
302 owner: &str,
303 repo: &str,
304 pr_number: &str,
305 token: &str,
306 message: &str,
307) -> Result<(), Box<dyn std::error::Error>> {
308 let client = Client::new();
309 let base_url = format!("https://api.github.com/repos/{}/{}", owner, repo);
310
311 let comments_url = format!("{}/issues/{}/comments", base_url, pr_number);
313 let comments_response = client
314 .get(&comments_url)
315 .header("Authorization", format!("token {}", token))
316 .header("User-Agent", "sp1-perf-bot")
317 .send()
318 .await?;
319
320 let comments: Vec<serde_json::Value> = comments_response.json().await?;
321
322 let bot_comment = comments.iter().find(|comment| {
324 comment["user"]["login"]
325 .as_str()
326 .map(|login| login == "github-actions[bot]")
327 .unwrap_or(false)
328 });
329
330 if let Some(existing_comment) = bot_comment {
331 let comment_url = existing_comment["url"].as_str().unwrap();
333 let response = client
334 .patch(comment_url)
335 .header("Authorization", format!("token {}", token))
336 .header("User-Agent", "sp1-perf-bot")
337 .json(&json!({
338 "body": message
339 }))
340 .send()
341 .await?;
342
343 if !response.status().is_success() {
344 return Err(format!("Failed to update comment: {:?}", response.text().await?).into());
345 }
346 } else {
347 let response = client
349 .post(&comments_url)
350 .header("Authorization", format!("token {}", token))
351 .header("User-Agent", "sp1-perf-bot")
352 .json(&json!({
353 "body": message
354 }))
355 .send()
356 .await?;
357
358 if !response.status().is_success() {
359 return Err(format!("Failed to post comment: {:?}", response.text().await?).into());
360 }
361 }
362
363 Ok(())
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful report with the given measurements.
    fn sample_report(
        program: &str,
        cycles: u64,
        exec_khz: f64,
        core_khz: f64,
        compressed_khz: f64,
        time: f64,
    ) -> PerformanceReport {
        PerformanceReport {
            program: program.to_string(),
            cycles,
            exec_khz,
            core_khz,
            compressed_khz,
            time,
            success: true,
        }
    }

    /// `format_results` yields a title, the run metadata and a table naming every program.
    #[test]
    fn test_format_results() {
        let reports = [
            sample_report("fibonacci", 11291, 29290.0, 30.0, 0.1, 622.385),
            sample_report("super-program", 275735600, 70190.0, 310.0, 120.0, 812.285),
        ];

        let owned = |s: &str| Some(s.to_string());
        let args = EvalArgs {
            programs: ["fibonacci", "super-program"].iter().map(|s| s.to_string()).collect(),
            shard_size: None,
            post_to_slack: Some(false),
            slack_channel_id: None,
            slack_token: None,
            post_to_github: Some(true),
            github_token: owned("abcdef1234567890"),
            repo_owner: owned("succinctlabs"),
            repo_name: owned("sp1"),
            pr_number: owned("123456"),
            branch_name: owned("feature-branch"),
            commit_hash: owned("abcdef1234567890"),
            author: owned("John Doe"),
        };

        let sections = format_results(&args, &reports);
        sections.iter().for_each(|section| println!("{}", section));

        assert_eq!(sections.len(), 3);
        let (title, details, table) = (&sections[0], &sections[1], &sections[2]);

        assert!(title.contains("SP1 Performance Test Results"));
        for expected in ["*Branch*: feature-branch", "*Commit*: abcdef12", "*Author*: John Doe"] {
            assert!(details.contains(expected), "details missing {:?}", expected);
        }
        for program in ["fibonacci", "super-program"] {
            assert!(table.contains(program), "table missing {:?}", program);
        }
    }
}