Skip to main content

sp1_cli/commands/
install_toolchain.rs

1use std::{
2    fs::{self},
3    io::Read,
4    process::Command,
5};
6
7use anyhow::Result;
8use clap::Parser;
9use dirs::home_dir;
10use indicatif::{ProgressBar, ProgressStyle};
11use rand::{distributions::Alphanumeric, Rng};
12use reqwest::Client;
13
14#[cfg(target_family = "unix")]
15use std::os::unix::fs::PermissionsExt;
16
17use crate::{
18    get_target, get_toolchain_download_url, is_supported_target, url_exists, RUSTUP_TOOLCHAIN_NAME,
19};
20
21#[derive(Parser)]
22#[command(name = "install-toolchain", about = "Install the cargo-prove toolchain.")]
23pub struct InstallToolchainCmd {
24    #[arg(short, long, env = "GITHUB_TOKEN")]
25    pub token: Option<String>,
26}
27
28impl InstallToolchainCmd {
29    #[allow(clippy::uninlined_format_args)]
30    pub async fn run(&self) -> Result<()> {
31        // Check if rust is installed.
32        if Command::new("rustup")
33            .arg("--version")
34            .stdout(std::process::Stdio::null())
35            .stderr(std::process::Stdio::null())
36            .status()
37            .is_err()
38        {
39            return Err(anyhow::anyhow!(
40                "Rust is not installed. Please install Rust from https://rustup.rs/ and try again."
41            ));
42        }
43
44        // Setup client with optional token.
45        let client_builder = Client::builder().user_agent("Mozilla/5.0");
46        let client = if let Some(ref token) = self.token {
47            client_builder
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        reqwest::header::AUTHORIZATION,
52                        reqwest::header::HeaderValue::from_str(&format!("token {token}")).unwrap(),
53                    );
54                    headers
55                })
56                .build()?
57        } else {
58            client_builder.build()?
59        };
60
61        // Setup variables.
62        let root_dir = home_dir().unwrap().join(".sp1");
63        match fs::read_dir(&root_dir) {
64            Ok(entries) =>
65            {
66                #[allow(clippy::manual_flatten)]
67                for entry in entries {
68                    if let Ok(entry) = entry {
69                        let entry_path = entry.path();
70                        let entry_name = entry_path.file_name().unwrap();
71                        if entry_path.is_dir()
72                            && entry_name != "bin"
73                            && entry_name != "circuits"
74                            && entry_name != "toolchains"
75                        {
76                            if let Err(err) = fs::remove_dir_all(&entry_path) {
77                                println!("Failed to remove directory {entry_path:?}: {err}");
78                            }
79                        } else if entry_path.is_file() {
80                            if let Err(err) = fs::remove_file(&entry_path) {
81                                println!("Failed to remove file {entry_path:?}: {err}");
82                            }
83                        }
84                    }
85                }
86            }
87            Err(_) => println!("No existing ~/.sp1 directory to remove."),
88        }
89        println!("Successfully cleaned up ~/.sp1 directory.");
90        match fs::create_dir_all(&root_dir) {
91            Ok(_) => println!("Successfully created ~/.sp1 directory."),
92            Err(err) => println!("Failed to create ~/.sp1 directory: {err}"),
93        };
94
95        assert!(
96            is_supported_target(),
97            "Unsupported architecture. Please build the toolchain from source."
98        );
99        let target = get_target();
100        let toolchain_asset_name = format!("rust-toolchain-{target}.tar.gz");
101        let toolchain_archive_path = root_dir.join(toolchain_asset_name.clone());
102        let toolchain_dir = root_dir.join(&target);
103
104        let toolchain_download_url = get_toolchain_download_url(&client, target.to_string()).await;
105
106        let artifact_exists = url_exists(&client, toolchain_download_url.as_str()).await;
107        if !artifact_exists {
108            return Err(anyhow::anyhow!(
109                "Unsupported architecture. Please build the toolchain from source."
110            ));
111        }
112
113        // Download the toolchain.
114        let mut file = tokio::fs::File::create(toolchain_archive_path).await.unwrap();
115        download_file(&client, toolchain_download_url.as_str(), &mut file).await.unwrap();
116
117        // Remove the existing toolchain from rustup, if it exists.
118        let mut child = Command::new("rustup")
119            .current_dir(&root_dir)
120            .args(["toolchain", "remove", RUSTUP_TOOLCHAIN_NAME])
121            .stdout(std::process::Stdio::piped())
122            .spawn()?;
123        let res = child.wait();
124        match res {
125            Ok(_) => {
126                let mut stdout = child.stdout.take().unwrap();
127                let mut content = String::new();
128                stdout.read_to_string(&mut content).unwrap();
129                if !content.contains("no toolchain installed") {
130                    println!("Successfully removed existing toolchain.");
131                }
132            }
133            Err(_) => println!("Failed to remove existing toolchain."),
134        }
135
136        // Unpack the toolchain.
137        fs::create_dir_all(toolchain_dir.clone())?;
138        Command::new("tar")
139            .current_dir(&root_dir)
140            .args(["-xzf", &toolchain_asset_name, "-C", &toolchain_dir.to_string_lossy()])
141            .status()?;
142
143        // Move the toolchain to a randomly named directory in the 'toolchains' folder
144        let toolchains_dir = root_dir.join("toolchains");
145        fs::create_dir_all(&toolchains_dir)?;
146        let random_string: String =
147            rand::thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect();
148        let new_toolchain_dir = toolchains_dir.join(random_string);
149        fs::rename(&toolchain_dir, &new_toolchain_dir)?;
150
151        // Link the new toolchain directory to rustup
152        Command::new("rustup")
153            .current_dir(&root_dir)
154            .args([
155                "toolchain",
156                "link",
157                RUSTUP_TOOLCHAIN_NAME,
158                &new_toolchain_dir.to_string_lossy(),
159            ])
160            .status()?;
161        println!("Successfully linked toolchain to rustup.");
162
163        // Ensure permissions.
164        let bin_dir = new_toolchain_dir.join("bin");
165        let rustlib_bin_dir = new_toolchain_dir.join(format!("lib/rustlib/{target}/bin"));
166        for entry in fs::read_dir(bin_dir)?.chain(fs::read_dir(rustlib_bin_dir)?) {
167            let entry = entry?;
168            if entry.path().is_file() {
169                let mut perms = entry.metadata()?.permissions();
170                perms.set_mode(0o755);
171                fs::set_permissions(entry.path(), perms)?;
172            }
173        }
174
175        Ok(())
176    }
177}
178
179pub async fn download_file(
180    client: &Client,
181    url: &str,
182    file: &mut (impl tokio::io::AsyncWrite + Unpin),
183) -> std::result::Result<(), String> {
184    use futures::StreamExt;
185    use tokio::io::AsyncWriteExt;
186
187    let res = client.get(url).send().await.or(Err(format!("Failed to GET from '{}'", &url)))?;
188
189    let total_size =
190        res.content_length().ok_or(format!("Failed to get content length from '{}'", &url))?;
191
192    let pb = ProgressBar::new(total_size);
193    pb.set_style(ProgressStyle::default_bar()
194        .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})").unwrap()
195        .progress_chars("#>-"));
196
197    let mut downloaded: u64 = 0;
198    let mut stream = res.bytes_stream();
199    while let Some(item) = stream.next().await {
200        let chunk = item.or(Err("Error while downloading file"))?;
201        file.write_all(&chunk).await.or(Err("Error while writing to file"))?;
202        let new = (downloaded + (chunk.len() as u64)).min(total_size);
203        downloaded = new;
204        pb.set_position(new);
205    }
206    pb.finish();
207
208    Ok(())
209}