1#![deny(rust_2018_idioms)]
3
4use std::path::Path;
5
6use anyhow::Context;
7pub use anyhow::Result;
8use url::Url;
9
10mod errors;
11pub use crate::errors::*;
12
13const USER_AGENT: &str = concat!("rustup/", env!("CARGO_PKG_VERSION"));
16
17#[derive(Debug, Copy, Clone)]
18pub enum Backend {
19 Curl,
20 Reqwest(TlsBackend),
21}
22
/// TLS implementation used by the reqwest backend.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TlsBackend {
    /// reqwest's rustls support (needs the `reqwest-rustls-tls` feature).
    Rustls,
    /// reqwest's default TLS stack (needs the `reqwest-default-tls` feature).
    Default,
}
28
/// Progress notification delivered to a download callback.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Event<'a> {
    /// An existing partial file is being resumed. It is sent before that
    /// file's bytes are replayed through `DownloadDataReceived`.
    ResumingPartialDownload,
    /// Expected total size in bytes. This includes any bytes already
    /// present from a resumed partial download.
    DownloadContentLengthReceived(u64),
    /// A chunk of downloaded bytes (also used to replay a resumed partial file).
    DownloadDataReceived(&'a [u8]),
}
37
38fn download_with_backend(
39 backend: Backend,
40 url: &Url,
41 resume_from: u64,
42 callback: &dyn Fn(Event<'_>) -> Result<()>,
43) -> Result<()> {
44 match backend {
45 Backend::Curl => curl::download(url, resume_from, callback),
46 Backend::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls),
47 }
48}
49
50pub fn download_to_path_with_backend(
51 backend: Backend,
52 url: &Url,
53 path: &Path,
54 resume_from_partial: bool,
55 callback: Option<&dyn Fn(Event<'_>) -> Result<()>>,
56) -> Result<()> {
57 use std::cell::RefCell;
58 use std::fs::remove_file;
59 use std::fs::OpenOptions;
60 use std::io::{Read, Seek, SeekFrom, Write};
61
62 || -> Result<()> {
63 let (file, resume_from) = if resume_from_partial {
64 let possible_partial = OpenOptions::new().read(true).open(&path);
65
66 let downloaded_so_far = if let Ok(mut partial) = possible_partial {
67 if let Some(cb) = callback {
68 cb(Event::ResumingPartialDownload)?;
69
70 let mut buf = vec![0; 32768];
71 let mut downloaded_so_far = 0;
72 loop {
73 let n = partial.read(&mut buf)?;
74 downloaded_so_far += n as u64;
75 if n == 0 {
76 break;
77 }
78 cb(Event::DownloadDataReceived(&buf[..n]))?;
79 }
80
81 downloaded_so_far
82 } else {
83 let file_info = partial.metadata()?;
84 file_info.len()
85 }
86 } else {
87 0
88 };
89
90 let mut possible_partial = OpenOptions::new()
91 .write(true)
92 .create(true)
93 .open(&path)
94 .context("error opening file for download")?;
95
96 possible_partial.seek(SeekFrom::End(0))?;
97
98 (possible_partial, downloaded_so_far)
99 } else {
100 (
101 OpenOptions::new()
102 .write(true)
103 .create(true)
104 .open(&path)
105 .context("error creating file for download")?,
106 0,
107 )
108 };
109
110 let file = RefCell::new(file);
111
112 download_with_backend(backend, url, resume_from, &|event| {
113 if let Event::DownloadDataReceived(data) = event {
114 file.borrow_mut()
115 .write_all(data)
116 .context("unable to write download to disk")?;
117 }
118 match callback {
119 Some(cb) => cb(event),
120 None => Ok(()),
121 }
122 })?;
123
124 file.borrow_mut()
125 .sync_data()
126 .context("unable to sync download to disk")?;
127
128 Ok(())
129 }()
130 .map_err(|e| {
131 if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") {
133 file_err.context(e)
134 } else {
135 e
136 }
137 })
138}
139
#[cfg(feature = "curl-backend")]
pub mod curl {
    use std::cell::RefCell;
    use std::str;
    use std::time::Duration;

    use anyhow::{Context, Result};
    use curl::easy::Easy;
    use url::Url;

    use super::Event;
    use crate::errors::*;

    /// Downloads `url` with libcurl, reporting progress through `callback`.
    ///
    /// When `resume_from` is non-zero, curl is asked to start the transfer at
    /// that byte offset. The content length reported to `callback` is the
    /// server's `Content-Length` plus `resume_from`.
    ///
    /// # Errors
    ///
    /// - the error returned by `callback`, unchanged, if it aborted the transfer;
    /// - `DownloadError::FileNotFound` (attached as context) when curl reports
    ///   that it could not read a file source;
    /// - `DownloadError::HttpStatus` for a response code other than `0` or
    ///   `200..=299`;
    /// - any other curl error, with `"error during download"` as context.
    pub fn download(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
    ) -> Result<()> {
        // One curl handle per thread, reused across calls. Options persist on a
        // reused handle, so each one this function relies on is set again
        // below, including clearing a previous resume offset.
        thread_local!(static EASY: RefCell<Easy> = RefCell::new(Easy::new()));
        EASY.with(|handle| {
            let mut handle = handle.borrow_mut();

            handle.url(url.as_ref())?;
            handle.follow_location(true)?;
            handle.useragent(super::USER_AGENT)?;

            if resume_from > 0 {
                handle.resume_from(resume_from)?;
            } else {
                // Best-effort reset of an offset left by an earlier download.
                let _ = handle.resume_from(0);
            }

            handle.connect_timeout(Duration::new(30, 0))?;

            {
                // Stores an error raised by `callback` inside a curl callback.
                // Those callbacks can only signal failure by aborting the transfer.
                let cberr = RefCell::new(None);
                let mut transfer = handle.transfer();

                transfer.write_function(|data| {
                    match callback(Event::DownloadDataReceived(data)) {
                        Ok(()) => Ok(data.len()),
                        Err(e) => {
                            *cberr.borrow_mut() = Some(e);
                            // Consuming fewer bytes than offered aborts the transfer.
                            Ok(0)
                        }
                    }
                })?;

                transfer.header_function(|header| {
                    const PREFIX: &str = "content-length: ";
                    // Header names are case-insensitive; compare the prefix in
                    // place rather than allocating a lowercased copy. These
                    // headers are ignored: non-UTF-8, unparseable, or with a
                    // total that would overflow u64.
                    let total_len = str::from_utf8(header)
                        .ok()
                        .and_then(|line| {
                            let name = line.get(..PREFIX.len())?;
                            if !name.eq_ignore_ascii_case(PREFIX) {
                                return None;
                            }
                            line[PREFIX.len()..].trim().parse::<u64>().ok()
                        })
                        .and_then(|len| len.checked_add(resume_from));

                    if let Some(len) = total_len {
                        if let Err(e) = callback(Event::DownloadContentLengthReceived(len)) {
                            *cberr.borrow_mut() = Some(e);
                            return false;
                        }
                    }
                    true
                })?;

                transfer.perform().or_else(|e| {
                    // If one of our callbacks aborted the transfer, its error
                    // explains the failure better than curl's generic one.
                    match cberr.borrow_mut().take() {
                        Some(cberr) => Err(cberr),
                        None if e.is_file_couldnt_read_file() => {
                            Err(e).context(DownloadError::FileNotFound)
                        }
                        None => Err(e).context("error during download"),
                    }
                })?;
            }

            // NOTE(review): presumably code 0 is what curl reports when no HTTP
            // status was received (e.g. non-HTTP transfers); it is accepted
            // alongside any 2xx.
            let code = handle.response_code()?;
            match code {
                0 | 200..=299 => Ok(()),
                _ => Err(DownloadError::HttpStatus(code).into()),
            }
        })
    }
}
254
#[cfg(feature = "reqwest-backend")]
pub mod reqwest_be {
    use std::io;
    use std::time::Duration;

    use anyhow::{anyhow, Context, Result};
    use lazy_static::lazy_static;
    use reqwest::blocking::{Client, ClientBuilder, Response};
    use reqwest::{header, Proxy, StatusCode};
    use url::Url;

    use super::Event;
    use super::TlsBackend;
    use crate::errors::*;

    /// Downloads `url` with reqwest, reporting progress through `callback`.
    ///
    /// `file://` URLs are read straight from disk. Any other URL is fetched
    /// with a GET through the shared client for `tls`. A non-zero
    /// `resume_from` adds a `Range: bytes=<resume_from>-` header. The content
    /// length reported to `callback` is the response's `Content-Length` plus
    /// `resume_from`.
    ///
    /// # Errors
    ///
    /// - `DownloadError::HttpStatus` for a non-success status;
    /// - `DownloadError::Message` when a resumed request is answered with
    ///   anything other than `206 Partial Content` (e.g. a `200 OK` carrying
    ///   the whole file). Such a body is not a continuation of the first
    ///   `resume_from` bytes;
    /// - errors from building or sending the request, reading the body, or
    ///   `callback`.
    pub fn download(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
        tls: TlsBackend,
    ) -> Result<()> {
        // A file:// URL is served from disk and needs no HTTP client.
        if download_from_file_url(url, resume_from, callback)? {
            return Ok(());
        }

        let mut res = request(url, resume_from, tls).context("failed to make network request")?;

        let status = res.status();
        if !status.is_success() {
            return Err(anyhow!(DownloadError::HttpStatus(u32::from(
                status.as_u16()
            ))));
        }
        if resume_from != 0 && status != StatusCode::PARTIAL_CONTENT {
            // The server ignored the Range header; its body starts at byte 0.
            return Err(anyhow!(DownloadError::Message(format!(
                "server did not resume download at byte {} (status: {})",
                resume_from, status
            ))));
        }

        // Content-Length comes from an untrusted server. Skip it if missing,
        // malformed, or overflowing instead of panicking, as the curl backend does.
        let total_len = res
            .headers()
            .get(header::CONTENT_LENGTH)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.trim().parse::<u64>().ok())
            .and_then(|len| len.checked_add(resume_from));
        if let Some(len) = total_len {
            callback(Event::DownloadContentLengthReceived(len))?;
        }

        let mut buffer = vec![0u8; 0x10000];
        loop {
            let bytes_read = io::Read::read(&mut res, &mut buffer)?;
            if bytes_read == 0 {
                return Ok(());
            }
            callback(Event::DownloadDataReceived(&buffer[..bytes_read]))?;
        }
    }

    /// Builder settings shared by every client:
    /// - gzip decoding disabled;
    /// - the rustup user agent;
    /// - proxies looked up per URL from the environment;
    /// - a 30 second timeout.
    fn client_generic() -> ClientBuilder {
        Client::builder()
            .gzip(false)
            .user_agent(super::USER_AGENT)
            .proxy(Proxy::custom(env_proxy))
            .timeout(Duration::from_secs(30))
    }

    #[cfg(feature = "reqwest-rustls-tls")]
    lazy_static! {
        // Built lazily on first use, then shared by every later request.
        static ref CLIENT_RUSTLS_TLS: Client = client_generic()
            .use_rustls_tls()
            .build()
            .expect("failed to build reqwest client with rustls TLS");
    }

    #[cfg(feature = "reqwest-default-tls")]
    lazy_static! {
        // Built lazily on first use, then shared by every later request.
        static ref CLIENT_DEFAULT_TLS: Client = client_generic()
            .build()
            .expect("failed to build reqwest client with default TLS");
    }

    /// Proxy to use for `url`, as configured in the environment (if any).
    fn env_proxy(url: &Url) -> Option<Url> {
        env_proxy::for_url(url).to_url()
    }

    /// Sends a GET for `url` with the client matching `backend`. A `Range`
    /// header is added when `resume_from` is non-zero.
    ///
    /// Fails with `DownloadError::BackendUnavailable` if the requested TLS
    /// backend was not compiled in.
    fn request(
        url: &Url,
        resume_from: u64,
        backend: TlsBackend,
    ) -> Result<Response, DownloadError> {
        let client: &Client = match backend {
            #[cfg(feature = "reqwest-rustls-tls")]
            TlsBackend::Rustls => &CLIENT_RUSTLS_TLS,
            #[cfg(not(feature = "reqwest-rustls-tls"))]
            TlsBackend::Rustls => {
                return Err(DownloadError::BackendUnavailable("reqwest rustls"));
            }
            #[cfg(feature = "reqwest-default-tls")]
            TlsBackend::Default => &CLIENT_DEFAULT_TLS,
            #[cfg(not(feature = "reqwest-default-tls"))]
            TlsBackend::Default => {
                return Err(DownloadError::BackendUnavailable("reqwest default TLS"));
            }
        };
        let mut req = client.get(url.as_str());

        if resume_from != 0 {
            req = req.header(header::RANGE, format!("bytes={}-", resume_from));
        }

        Ok(req.send()?)
    }

    /// Serves a `file://` URL by streaming the file from disk, starting at
    /// byte `resume_from`.
    ///
    /// Returns `Ok(false)` without doing anything for other schemes. Returns
    /// `Ok(true)` once the rest of the file has been passed to `callback`.
    /// Fails with `DownloadError::FileNotFound` if the path is not a regular file.
    fn download_from_file_url(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
    ) -> Result<bool> {
        use std::fs;

        if url.scheme() != "file" {
            return Ok(false);
        }

        let src = url
            .to_file_path()
            .map_err(|_| DownloadError::Message(format!("bogus file url: '{}'", url)))?;
        if !src.is_file() {
            return Err(anyhow!(DownloadError::FileNotFound));
        }

        let mut f = fs::File::open(src).context("unable to open downloaded file")?;
        io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?;

        let mut buffer = vec![0u8; 0x10000];
        loop {
            let bytes_read = io::Read::read(&mut f, &mut buffer)?;
            if bytes_read == 0 {
                break;
            }
            callback(Event::DownloadDataReceived(&buffer[..bytes_read]))?;
        }

        Ok(true)
    }
}
420
421#[cfg(not(feature = "curl-backend"))]
422pub mod curl {
423
424 use anyhow::{anyhow, Result};
425
426 use super::Event;
427 use crate::errors::*;
428 use url::Url;
429
430 pub fn download(
431 _url: &Url,
432 _resume_from: u64,
433 _callback: &dyn Fn(Event<'_>) -> Result<()>,
434 ) -> Result<()> {
435 Err(anyhow!(DownloadError::BackendUnavailable("curl")))
436 }
437}
438
439#[cfg(not(feature = "reqwest-backend"))]
440pub mod reqwest_be {
441
442 use anyhow::{anyhow, Result};
443
444 use super::Event;
445 use super::TlsBackend;
446 use crate::errors::*;
447 use url::Url;
448
449 pub fn download(
450 _url: &Url,
451 _resume_from: u64,
452 _callback: &dyn Fn(Event<'_>) -> Result<()>,
453 _tls: TlsBackend,
454 ) -> Result<()> {
455 Err(anyhow!(DownloadError::BackendUnavailable("reqwest")))
456 }
457}