sp1_eval/
lib.rs

1use anyhow::Result;
2use clap::{command, Parser};
3use reqwest::Client;
4use serde::Serialize;
5use serde_json::json;
6use slack_rust::{
7    chat::post_message::{post_message, PostMessageRequest},
8    http_client::default_client,
9};
10use sp1_prover::{components::SP1ProverComponents, utils::get_cycles, SP1Prover};
11use sp1_sdk::{SP1Context, SP1Stdin};
12use sp1_stark::SP1ProverOpts;
13use std::time::{Duration, Instant};
14
15use program::load_program;
16
17use crate::program::{TesterProgram, PROGRAMS};
18
19mod program;
20
// Command-line arguments for the evaluation run, parsed with clap's derive API.
//
// NOTE(review): the `///` comments on the fields below sit on a `Parser` derive, so clap
// presumably surfaces them as `--help` text. Treat them as user-facing strings and edit them
// with that in mind.
#[derive(Parser, Clone)]
#[command(about = "Evaluate the performance of SP1 on programs.")]
struct EvalArgs {
    /// The programs to evaluate, specified by name. If not specified, all programs will be
    /// evaluated.
    // Accepts a comma-separated list; names are matched case-insensitively against `PROGRAMS`
    // in `evaluate_performance`.
    #[arg(long, use_value_delimiter = true, value_delimiter = ',')]
    pub programs: Vec<String>,

    /// The shard size to use for the prover.
    // Interpreted as a power-of-two exponent: `evaluate_performance` exports
    // `2^shard_size` through the `SHARD_SIZE` environment variable.
    #[arg(long)]
    pub shard_size: Option<usize>,

    /// Whether to post results to Slack.
    // `num_args=0..=1` with `default_missing_value="true"`: the flag may be given bare or with
    // an explicit value. When the flag is absent the field is `None`, which the caller treats
    // as `false`.
    #[arg(long, default_missing_value="true", num_args=0..=1)]
    pub post_to_slack: Option<bool>,

    /// The Slack channel ID to post results to, only used if post_to_slack is true.
    #[arg(long)]
    pub slack_channel_id: Option<String>,

    /// The Slack bot token to post results to, only used if post_to_slack is true.
    #[arg(long)]
    pub slack_token: Option<String>,

    /// Whether to post results to GitHub PR.
    // Same bare-flag-or-value handling as `post_to_slack`.
    #[arg(long, default_missing_value="true", num_args=0..=1)]
    pub post_to_github: Option<bool>,

    /// The GitHub token for authentication, only used if post_to_github is true.
    #[arg(long)]
    pub github_token: Option<String>,

    /// The GitHub repository owner.
    #[arg(long)]
    pub repo_owner: Option<String>,

    /// The GitHub repository name.
    #[arg(long)]
    pub repo_name: Option<String>,

    /// The GitHub PR number.
    // Kept as a string because it is only interpolated into the GitHub API URL.
    #[arg(long)]
    pub pr_number: Option<String>,

    /// The name of the branch.
    // `branch_name`, `commit_hash` and `author` are only used to decorate the report header
    // built by `format_results`.
    #[arg(long)]
    pub branch_name: Option<String>,

    /// The commit hash.
    #[arg(long)]
    pub commit_hash: Option<String>,

    /// The author of the commit.
    #[arg(long)]
    pub author: Option<String>,
}
77
78pub async fn evaluate_performance<C: SP1ProverComponents>(
79    opts: SP1ProverOpts,
80) -> Result<(), Box<dyn std::error::Error>> {
81    println!("opts: {opts:?}");
82
83    let args = EvalArgs::parse();
84
85    // Set environment variables to configure the prover.
86    if let Some(shard_size) = args.shard_size {
87        std::env::set_var("SHARD_SIZE", format!("{}", 1 << shard_size));
88    }
89
90    // Choose which programs to evaluate.
91    let programs: Vec<&TesterProgram> = if args.programs.is_empty() {
92        PROGRAMS.iter().collect()
93    } else {
94        PROGRAMS
95            .iter()
96            .filter(|p| args.programs.iter().any(|arg| arg.eq_ignore_ascii_case(p.name)))
97            .collect()
98    };
99
100    sp1_sdk::utils::setup_logger();
101
102    // Run the evaluations on each program.
103    let mut reports = Vec::new();
104    for program in &programs {
105        println!("Evaluating program: {}", program.name);
106        let (elf, stdin) = load_program(program.elf, program.input);
107        let report = run_evaluation::<C>(program.name, &elf, &stdin, opts);
108        reports.push(report);
109        println!("Finished Program: {}", program.name);
110    }
111
112    // Prepare and format the results.
113    let reports_len = reports.len();
114    let success_count = reports.iter().filter(|r| r.success).count();
115    let results_text = format_results(&args, &reports);
116
117    // Print results
118    println!("{}", results_text.join("\n"));
119
120    // Post to Slack if applicable
121    if args.post_to_slack.unwrap_or(false) {
122        match (&args.slack_token, &args.slack_channel_id) {
123            (Some(token), Some(channel)) => {
124                for message in &results_text {
125                    post_to_slack(token, channel, message).await?;
126                }
127            }
128            _ => println!("Warning: post_to_slack is true, required Slack arguments are missing."),
129        }
130    }
131
132    // Post to GitHub PR if applicable
133    if args.post_to_github.unwrap_or(false) {
134        match (&args.repo_owner, &args.repo_name, &args.pr_number, &args.github_token) {
135            (Some(owner), Some(repo), Some(pr_number), Some(token)) => {
136                let message = format_github_message(&results_text);
137                post_to_github_pr(owner, repo, pr_number, token, &message).await?;
138            }
139            _ => {
140                println!("Warning: post_to_github is true, required GitHub arguments are missing.")
141            }
142        }
143    }
144
145    // Exit with an error if any programs failed.
146    let all_successful = success_count == reports_len;
147    if !all_successful {
148        println!("Some programs failed. Please check the results above.");
149        std::process::exit(1);
150    }
151
152    Ok(())
153}
154
/// Performance measurements for a single evaluated program, as produced by `run_evaluation`.
#[derive(Debug, Serialize)]
pub struct PerformanceReport {
    /// Name of the evaluated program.
    program: String,
    /// Cycle count of the program, as returned by `get_cycles`.
    cycles: u64,
    /// Execution throughput in kHz: `cycles` divided by the execution time.
    exec_khz: f64,
    /// Core proving throughput in kHz: `cycles` divided by the core proving time.
    core_khz: f64,
    /// Throughput in kHz over core proving plus compression: `cycles` divided by the sum of
    /// both durations.
    compressed_khz: f64,
    /// Total wall-clock time in seconds (execution + core proving + compression).
    time: f64,
    /// Whether the evaluation succeeded.
    success: bool,
}
165
166fn run_evaluation<C: SP1ProverComponents>(
167    program_name: &str,
168    elf: &[u8],
169    stdin: &SP1Stdin,
170    opts: SP1ProverOpts,
171) -> PerformanceReport {
172    let cycles = get_cycles(elf, stdin);
173
174    let prover = SP1Prover::<C>::new();
175    let (_, pk_d, program, vk) = prover.setup(elf);
176
177    let context = SP1Context::default();
178
179    let (_, exec_duration) = time_operation(|| prover.execute(elf, stdin, context.clone()));
180
181    let (core_proof, core_duration) =
182        time_operation(|| prover.prove_core(&pk_d, program, stdin, opts, context).unwrap());
183
184    let (_, compress_duration) =
185        time_operation(|| prover.compress(&vk, core_proof, vec![], opts).unwrap());
186
187    let total_duration = exec_duration + core_duration + compress_duration;
188
189    PerformanceReport {
190        program: program_name.to_string(),
191        cycles,
192        exec_khz: calculate_khz(cycles, exec_duration),
193        core_khz: calculate_khz(cycles, core_duration),
194        compressed_khz: calculate_khz(cycles, compress_duration + core_duration),
195        time: total_duration.as_secs_f64(),
196        success: true,
197    }
198}
199
200fn format_results(args: &EvalArgs, results: &[PerformanceReport]) -> Vec<String> {
201    let mut detail_text = String::new();
202    if let Some(branch_name) = &args.branch_name {
203        detail_text.push_str(&format!("*Branch*: {branch_name}\n"));
204    }
205    if let Some(commit_hash) = &args.commit_hash {
206        detail_text.push_str(&format!("*Commit*: {}\n", &commit_hash[..8]));
207    }
208    if let Some(author) = &args.author {
209        detail_text.push_str(&format!("*Author*: {author}\n"));
210    }
211
212    let mut table_text = String::new();
213    table_text.push_str("```\n");
214    table_text.push_str("| program           | cycles      | execute (mHz)  | core (kHZ)     | compress (KHz) | time   | success  |\n");
215    table_text.push_str("|-------------------|-------------|----------------|----------------|----------------|--------|----------|");
216
217    for result in results.iter() {
218        table_text.push_str(&format!(
219            "\n| {:<17} | {:>11} | {:>14.2} | {:>14.2} | {:>14.2} | {:>6} | {:<7} |",
220            result.program,
221            result.cycles,
222            result.exec_khz / 1000.0,
223            result.core_khz,
224            result.compressed_khz,
225            format_duration(result.time),
226            if result.success { "✅" } else { "❌" }
227        ));
228    }
229    table_text.push_str("\n```");
230
231    vec!["*SP1 Performance Test Results*\n".to_string(), detail_text, table_text]
232}
233
/// Runs `operation` once and returns its output together with the wall-clock time it took.
pub fn time_operation<T, F: FnOnce() -> T>(operation: F) -> (T, Duration) {
    let timer = Instant::now();
    let output = operation();
    (output, timer.elapsed())
}
240
/// Converts a cycle count and an elapsed time into a frequency in kHz.
///
/// Returns `0.0` for a zero-length duration instead of dividing by zero.
fn calculate_khz(cycles: u64, duration: Duration) -> f64 {
    let elapsed_secs = duration.as_secs_f64();
    if !(elapsed_secs > 0.0) {
        return 0.0;
    }

    let cycles_per_sec = cycles as f64 / elapsed_secs;
    cycles_per_sec / 1_000.0
}
249
/// Renders a duration given in seconds as a short human-readable string.
///
/// The value is rounded to whole seconds and shown as `XmYs` when it is at least a minute,
/// or as `Ys` when it is at least a second. Anything that rounds to zero seconds is shown in
/// milliseconds instead.
fn format_duration(duration: f64) -> String {
    let whole_secs = duration.round() as u64;

    match (whole_secs / 60, whole_secs % 60) {
        // Rounds to zero seconds: fall back to millisecond precision.
        (0, 0) => format!("{}ms", (duration * 1000.0).round() as u64),
        (0, seconds) => format!("{seconds}s"),
        (minutes, seconds) => format!("{minutes}m{seconds}s"),
    }
}
263
264async fn post_to_slack(slack_token: &str, slack_channel_id: &str, message: &str) -> Result<()> {
265    let slack_api_client = default_client();
266    let request = PostMessageRequest {
267        channel: slack_channel_id.to_string(),
268        text: Some(message.to_string()),
269        ..Default::default()
270    };
271
272    post_message(&slack_api_client, &request, slack_token).await.expect("slack api call error");
273
274    Ok(())
275}
276
/// Converts the sections produced by `format_results` into a single GitHub comment body.
///
/// The first two sections (title and details) have every `*` doubled so that single-asterisk
/// bold markers become GitHub's `**bold**`, and each is followed by a newline. The third
/// section (the table) has its surrounding triple backticks stripped. Missing sections are
/// skipped, and anything beyond the third is ignored.
fn format_github_message(results_text: &[String]) -> String {
    let mut body = String::new();

    // Title and details: add an extra asterisk for GitHub bold formatting.
    for section in results_text.iter().take(2) {
        body.push_str(&section.replace('*', "**"));
        body.push('\n');
    }

    // Table: remove the triple backticks as GitHub doesn't require them for table formatting.
    if let Some(table) = results_text.get(2) {
        body.push_str(table.trim_start_matches("```").trim_end_matches("```"));
    }

    body
}
300
301async fn post_to_github_pr(
302    owner: &str,
303    repo: &str,
304    pr_number: &str,
305    token: &str,
306    message: &str,
307) -> Result<(), Box<dyn std::error::Error>> {
308    let client = Client::new();
309    let base_url = format!("https://api.github.com/repos/{owner}/{repo}");
310
311    // Get all comments on the PR
312    let comments_url = format!("{base_url}/issues/{pr_number}/comments");
313    let comments_response = client
314        .get(&comments_url)
315        .header("Authorization", format!("token {token}"))
316        .header("User-Agent", "sp1-perf-bot")
317        .send()
318        .await?;
319
320    let comments: Vec<serde_json::Value> = comments_response.json().await?;
321
322    // Look for an existing comment from our bot
323    let bot_comment = comments.iter().find(|comment| {
324        comment["user"]["login"]
325            .as_str()
326            .map(|login| login == "github-actions[bot]")
327            .unwrap_or(false)
328    });
329
330    if let Some(existing_comment) = bot_comment {
331        // Update the existing comment
332        let comment_url = existing_comment["url"].as_str().unwrap();
333        let response = client
334            .patch(comment_url)
335            .header("Authorization", format!("token {token}"))
336            .header("User-Agent", "sp1-perf-bot")
337            .json(&json!({
338                "body": message
339            }))
340            .send()
341            .await?;
342
343        if !response.status().is_success() {
344            return Err(format!("Failed to update comment: {:?}", response.text().await?).into());
345        }
346    } else {
347        // Create a new comment
348        let response = client
349            .post(&comments_url)
350            .header("Authorization", format!("token {token}"))
351            .header("User-Agent", "sp1-perf-bot")
352            .json(&json!({
353                "body": message
354            }))
355            .send()
356            .await?;
357
358        if !response.status().is_success() {
359            return Err(format!("Failed to post comment: {:?}", response.text().await?).into());
360        }
361    }
362
363    Ok(())
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `format_results` yields the title, details and table sections and that
    /// each one carries the expected content.
    #[test]
    fn test_format_results() {
        let cli_args = EvalArgs {
            programs: vec!["fibonacci".to_string(), "super-program".to_string()],
            shard_size: None,
            post_to_slack: Some(false),
            slack_channel_id: None,
            slack_token: None,
            post_to_github: Some(true),
            github_token: Some("abcdef1234567890".to_string()),
            repo_owner: Some("succinctlabs".to_string()),
            repo_name: Some("sp1".to_string()),
            pr_number: Some("123456".to_string()),
            branch_name: Some("feature-branch".to_string()),
            commit_hash: Some("abcdef1234567890".to_string()),
            author: Some("John Doe".to_string()),
        };

        let sample_reports = [
            PerformanceReport {
                program: "fibonacci".to_string(),
                cycles: 11291,
                exec_khz: 29290.0,
                core_khz: 30.0,
                compressed_khz: 0.1,
                time: 622.385,
                success: true,
            },
            PerformanceReport {
                program: "super-program".to_string(),
                cycles: 275735600,
                exec_khz: 70190.0,
                core_khz: 310.0,
                compressed_khz: 120.0,
                time: 812.285,
                success: true,
            },
        ];

        let sections = format_results(&cli_args, &sample_reports);

        // Echo the output so it can be inspected with `cargo test -- --nocapture`.
        sections.iter().for_each(|section| println!("{section}"));

        assert_eq!(sections.len(), 3);

        // Title section.
        assert!(sections[0].contains("SP1 Performance Test Results"));

        // Details section: branch, shortened commit hash, author.
        let details = &sections[1];
        assert!(details.contains("*Branch*: feature-branch"));
        assert!(details.contains("*Commit*: abcdef12"));
        assert!(details.contains("*Author*: John Doe"));

        // Table section: one row per report.
        let table = &sections[2];
        assert!(table.contains("fibonacci"));
        assert!(table.contains("super-program"));
    }
}
423}