sp1_sdk/
install.rs

1//! # SP1 Install
2//!
//! Helpers for locating, downloading, and installing the SP1 circuit artifacts (Groth16 and PLONK).
4
5use cfg_if::cfg_if;
6use std::path::PathBuf;
7
8#[cfg(any(feature = "network", feature = "network"))]
9use {
10    crate::utils::block_on,
11    futures::StreamExt,
12    indicatif::{ProgressBar, ProgressStyle},
13    reqwest::Client,
14    std::{cmp::min, process::Command},
15};
16
17use crate::SP1_CIRCUIT_VERSION;
18
/// Base URL of the S3 bucket that hosts the circuit artifact tarballs.
///
/// `install_circuit_artifacts` downloads
/// `{CIRCUIT_ARTIFACTS_URL_BASE}/{SP1_CIRCUIT_VERSION}-{artifacts_type}.tar.gz` from it.
pub const CIRCUIT_ARTIFACTS_URL_BASE: &str = "https://sp1-circuits.s3-us-east-2.amazonaws.com";
21
22/// The directory where the groth16 circuit artifacts will be stored.
23#[must_use]
24pub fn groth16_circuit_artifacts_dir() -> PathBuf {
25    std::env::var("SP1_GROTH16_CIRCUIT_PATH")
26        .map_or_else(
27            |_| dirs::home_dir().unwrap().join(".sp1").join("circuits/groth16"),
28            |path| path.parse().unwrap(),
29        )
30        .join(SP1_CIRCUIT_VERSION)
31}
32
33/// The directory where the plonk circuit artifacts will be stored.
34#[must_use]
35pub fn plonk_circuit_artifacts_dir() -> PathBuf {
36    std::env::var("SP1_PLONK_CIRCUIT_PATH")
37        .map_or_else(
38            |_| dirs::home_dir().unwrap().join(".sp1").join("circuits/plonk"),
39            |path| path.parse().unwrap(),
40        )
41        .join(SP1_CIRCUIT_VERSION)
42}
43
44/// Tries to install the groth16 circuit artifacts if they are not already installed.
45#[must_use]
46pub fn try_install_circuit_artifacts(artifacts_type: &str) -> PathBuf {
47    let build_dir = if artifacts_type == "groth16" {
48        groth16_circuit_artifacts_dir()
49    } else if artifacts_type == "plonk" {
50        plonk_circuit_artifacts_dir()
51    } else {
52        unimplemented!("unsupported artifacts type: {}", artifacts_type);
53    };
54
55    if build_dir.exists() {
56        eprintln!(
57            "[sp1] {} circuit artifacts already seem to exist at {}. if you want to re-download them, delete the directory",
58            artifacts_type,
59            build_dir.display()
60        );
61    } else {
62        cfg_if! {
63            if #[cfg(any(feature = "network", feature = "network"))] {
64                eprintln!(
65                    "[sp1] {} circuit artifacts for version {} do not exist at {}. downloading...",
66                    artifacts_type,
67                    SP1_CIRCUIT_VERSION,
68                    build_dir.display()
69                );
70                install_circuit_artifacts(build_dir.clone(), artifacts_type);
71            }
72        }
73    }
74    build_dir
75}
76
/// Downloads the circuit artifacts for [`SP1_CIRCUIT_VERSION`] and extracts them into `build_dir`.
///
/// The function downloads the tarball
/// `{CIRCUIT_ARTIFACTS_URL_BASE}/{SP1_CIRCUIT_VERSION}-{artifacts_type}.tar.gz` into a temporary
/// file, showing a progress bar via [`download_file`]. It then unpacks the file with the system
/// `tar` binary.
///
/// The archive is downloaded *before* `build_dir` is created. As a result, a failed download does
/// not leave behind an empty directory that [`try_install_circuit_artifacts`] would mistake for a
/// complete installation.
///
/// For the same reason, if extraction fails and `build_dir` was created by this call, the
/// directory is removed again before panicking.
///
/// # Panics
///
/// Panics if:
/// - the download fails;
/// - the build directory cannot be created;
/// - `tar` cannot be launched;
/// - `tar` exits with a non-success status.
#[cfg(any(feature = "network", feature = "network"))]
// `build_dir` is taken by value only to keep the existing public signature stable.
#[allow(clippy::needless_pass_by_value)]
pub fn install_circuit_artifacts(build_dir: PathBuf, artifacts_type: &str) {
    // Download the artifacts into a temporary file first.
    let download_url =
        format!("{CIRCUIT_ARTIFACTS_URL_BASE}/{SP1_CIRCUIT_VERSION}-{artifacts_type}.tar.gz");
    let mut artifacts_tar_gz_file =
        tempfile::NamedTempFile::new().expect("failed to create tempfile");
    let client = Client::builder().build().expect("failed to create reqwest client");
    block_on(download_file(&client, &download_url, &mut artifacts_tar_gz_file))
        .expect("failed to download file");

    // Create the build directory, remembering whether it is ours to clean up on failure.
    let created_build_dir = !build_dir.exists();
    std::fs::create_dir_all(&build_dir).expect("failed to create build directory");

    // Extract the tarball to the build directory.
    // Paths are passed as `OsStr`, so non-UTF-8 paths no longer cause a panic.
    // NOTE(review): `-P` keeps absolute member names (GNU and BSD tar), which would let entries
    // with absolute paths be written outside `build_dir`. The archive comes from the network,
    // so confirm the artifacts really need this flag.
    let status = Command::new("tar")
        .arg("-Pxzf")
        .arg(artifacts_tar_gz_file.path())
        .arg("-C")
        .arg(&build_dir)
        .status()
        .expect("failed to extract tarball");
    if !status.success() {
        if created_build_dir {
            // Best effort: a half-extracted directory would otherwise be treated as installed.
            let _ = std::fs::remove_dir_all(&build_dir);
        }
        panic!("failed to extract tarball downloaded from {download_url}: tar exited with {status}");
    }

    eprintln!("[sp1] downloaded {} to {:?}", download_url, build_dir);
}
110
/// Downloads `url` into `file`, rendering a progress bar while the body is streamed.
///
/// The response must carry a `Content-Length` header; it is used as the progress bar's total.
/// `file` is flushed once the whole body has been written.
///
/// # Errors
///
/// Returns a human-readable message if any of the following happens:
/// - the request cannot be sent;
/// - the server answers with an HTTP error status (4xx/5xx);
/// - the response has no content length;
/// - reading the body fails;
/// - writing to or flushing `file` fails.
#[cfg(any(feature = "network", feature = "network"))]
pub async fn download_file(
    client: &Client,
    url: &str,
    file: &mut impl std::io::Write,
) -> std::result::Result<(), String> {
    // `send` only fails on transport errors. HTTP error statuses are rejected explicitly so that
    // an error response body is never written out as if it were the requested file.
    let res = client
        .get(url)
        .send()
        .await
        .and_then(|response| response.error_for_status())
        .map_err(|e| format!("Failed to GET from '{url}': {e}"))?;

    let total_size = res
        .content_length()
        .ok_or_else(|| format!("Failed to get content length from '{url}'"))?;

    let pb = ProgressBar::new(total_size);
    pb.set_style(ProgressStyle::default_bar()
        .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
        .expect("progress bar template is a valid literal")
        .progress_chars("#>-"));

    let mut downloaded: u64 = 0;
    let mut stream = res.bytes_stream();
    while let Some(item) = stream.next().await {
        let chunk = item.map_err(|e| format!("Error while downloading file: {e}"))?;
        file.write_all(&chunk).map_err(|e| format!("Error while writing to file: {e}"))?;
        // Clamp so the bar never overshoots if more bytes arrive than were advertised.
        downloaded = min(downloaded + chunk.len() as u64, total_size);
        pb.set_position(downloaded);
    }
    // Callers may pass a buffered writer; make sure everything reaches the underlying sink.
    file.flush().map_err(|e| format!("Error while flushing file: {e}"))?;
    pb.finish();

    Ok(())
}