sp1_cli/commands/
install_toolchain.rs1use std::{
2 fs::{self},
3 io::Read,
4 process::Command,
5};
6
7use anyhow::Result;
8use clap::Parser;
9use dirs::home_dir;
10use indicatif::{ProgressBar, ProgressStyle};
11use rand::{distributions::Alphanumeric, Rng};
12use reqwest::Client;
13
14#[cfg(target_family = "unix")]
15use std::os::unix::fs::PermissionsExt;
16
17use crate::{
18 get_target, get_toolchain_download_url, is_supported_target, url_exists, RUSTUP_TOOLCHAIN_NAME,
19};
20
21#[derive(Parser)]
22#[command(name = "install-toolchain", about = "Install the cargo-prove toolchain.")]
23pub struct InstallToolchainCmd {
24 #[arg(short, long, env = "GITHUB_TOKEN")]
25 pub token: Option<String>,
26}
27
28impl InstallToolchainCmd {
29 #[allow(clippy::uninlined_format_args)]
30 pub async fn run(&self) -> Result<()> {
31 if Command::new("rustup")
33 .arg("--version")
34 .stdout(std::process::Stdio::null())
35 .stderr(std::process::Stdio::null())
36 .status()
37 .is_err()
38 {
39 return Err(anyhow::anyhow!(
40 "Rust is not installed. Please install Rust from https://rustup.rs/ and try again."
41 ));
42 }
43
44 let client_builder = Client::builder().user_agent("Mozilla/5.0");
46 let client = if let Some(ref token) = self.token {
47 client_builder
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert(
51 reqwest::header::AUTHORIZATION,
52 reqwest::header::HeaderValue::from_str(&format!("token {token}")).unwrap(),
53 );
54 headers
55 })
56 .build()?
57 } else {
58 client_builder.build()?
59 };
60
61 let root_dir = home_dir().unwrap().join(".sp1");
63 match fs::read_dir(&root_dir) {
64 Ok(entries) =>
65 {
66 #[allow(clippy::manual_flatten)]
67 for entry in entries {
68 if let Ok(entry) = entry {
69 let entry_path = entry.path();
70 let entry_name = entry_path.file_name().unwrap();
71 if entry_path.is_dir()
72 && entry_name != "bin"
73 && entry_name != "circuits"
74 && entry_name != "toolchains"
75 {
76 if let Err(err) = fs::remove_dir_all(&entry_path) {
77 println!("Failed to remove directory {entry_path:?}: {err}");
78 }
79 } else if entry_path.is_file() {
80 if let Err(err) = fs::remove_file(&entry_path) {
81 println!("Failed to remove file {entry_path:?}: {err}");
82 }
83 }
84 }
85 }
86 }
87 Err(_) => println!("No existing ~/.sp1 directory to remove."),
88 }
89 println!("Successfully cleaned up ~/.sp1 directory.");
90 match fs::create_dir_all(&root_dir) {
91 Ok(_) => println!("Successfully created ~/.sp1 directory."),
92 Err(err) => println!("Failed to create ~/.sp1 directory: {err}"),
93 };
94
95 assert!(
96 is_supported_target(),
97 "Unsupported architecture. Please build the toolchain from source."
98 );
99 let target = get_target();
100 let toolchain_asset_name = format!("rust-toolchain-{target}.tar.gz");
101 let toolchain_archive_path = root_dir.join(toolchain_asset_name.clone());
102 let toolchain_dir = root_dir.join(&target);
103
104 let toolchain_download_url = get_toolchain_download_url(&client, target.to_string()).await;
105
106 let artifact_exists = url_exists(&client, toolchain_download_url.as_str()).await;
107 if !artifact_exists {
108 return Err(anyhow::anyhow!(
109 "Unsupported architecture. Please build the toolchain from source."
110 ));
111 }
112
113 let mut file = tokio::fs::File::create(toolchain_archive_path).await.unwrap();
115 download_file(&client, toolchain_download_url.as_str(), &mut file).await.unwrap();
116
117 let mut child = Command::new("rustup")
119 .current_dir(&root_dir)
120 .args(["toolchain", "remove", RUSTUP_TOOLCHAIN_NAME])
121 .stdout(std::process::Stdio::piped())
122 .spawn()?;
123 let res = child.wait();
124 match res {
125 Ok(_) => {
126 let mut stdout = child.stdout.take().unwrap();
127 let mut content = String::new();
128 stdout.read_to_string(&mut content).unwrap();
129 if !content.contains("no toolchain installed") {
130 println!("Successfully removed existing toolchain.");
131 }
132 }
133 Err(_) => println!("Failed to remove existing toolchain."),
134 }
135
136 fs::create_dir_all(toolchain_dir.clone())?;
138 Command::new("tar")
139 .current_dir(&root_dir)
140 .args(["-xzf", &toolchain_asset_name, "-C", &toolchain_dir.to_string_lossy()])
141 .status()?;
142
143 let toolchains_dir = root_dir.join("toolchains");
145 fs::create_dir_all(&toolchains_dir)?;
146 let random_string: String =
147 rand::thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect();
148 let new_toolchain_dir = toolchains_dir.join(random_string);
149 fs::rename(&toolchain_dir, &new_toolchain_dir)?;
150
151 Command::new("rustup")
153 .current_dir(&root_dir)
154 .args([
155 "toolchain",
156 "link",
157 RUSTUP_TOOLCHAIN_NAME,
158 &new_toolchain_dir.to_string_lossy(),
159 ])
160 .status()?;
161 println!("Successfully linked toolchain to rustup.");
162
163 let bin_dir = new_toolchain_dir.join("bin");
165 let rustlib_bin_dir = new_toolchain_dir.join(format!("lib/rustlib/{target}/bin"));
166 for entry in fs::read_dir(bin_dir)?.chain(fs::read_dir(rustlib_bin_dir)?) {
167 let entry = entry?;
168 if entry.path().is_file() {
169 let mut perms = entry.metadata()?.permissions();
170 perms.set_mode(0o755);
171 fs::set_permissions(entry.path(), perms)?;
172 }
173 }
174
175 Ok(())
176 }
177}
178
179pub async fn download_file(
180 client: &Client,
181 url: &str,
182 file: &mut (impl tokio::io::AsyncWrite + Unpin),
183) -> std::result::Result<(), String> {
184 use futures::StreamExt;
185 use tokio::io::AsyncWriteExt;
186
187 let res = client.get(url).send().await.or(Err(format!("Failed to GET from '{}'", &url)))?;
188
189 let total_size =
190 res.content_length().ok_or(format!("Failed to get content length from '{}'", &url))?;
191
192 let pb = ProgressBar::new(total_size);
193 pb.set_style(ProgressStyle::default_bar()
194 .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})").unwrap()
195 .progress_chars("#>-"));
196
197 let mut downloaded: u64 = 0;
198 let mut stream = res.bytes_stream();
199 while let Some(item) = stream.next().await {
200 let chunk = item.or(Err("Error while downloading file"))?;
201 file.write_all(&chunk).await.or(Err("Error while writing to file"))?;
202 let new = (downloaded + (chunk.len() as u64)).min(total_size);
203 downloaded = new;
204 pb.set_position(new);
205 }
206 pb.finish();
207
208 Ok(())
209}