rustup_private_download/
lib.rs

1//! Easy file downloading
2#![deny(rust_2018_idioms)]
3
4use std::path::Path;
5
6use anyhow::Context;
7pub use anyhow::Result;
8use url::Url;
9
10mod errors;
11pub use crate::errors::*;
12
/// User agent header value for HTTP requests, in the form
/// `rustup/<version of this crate>`.
///
/// See: <https://github.com/rust-lang/rustup/issues/2860>.
const USER_AGENT: &str = concat!("rustup/", env!("CARGO_PKG_VERSION"));
16
/// The HTTP implementation used to perform a download.
#[derive(Debug, Copy, Clone)]
pub enum Backend {
    /// Download through the `curl` module.
    Curl,
    /// Download through the `reqwest_be` module, using the given TLS backend.
    Reqwest(TlsBackend),
}
22
/// The TLS implementation used by the reqwest backend.
#[derive(Debug, Copy, Clone)]
pub enum TlsBackend {
    /// Use the reqwest client that is built with `use_rustls_tls()`.
    Rustls,
    /// Use the reqwest client that is built without an explicit TLS choice.
    Default,
}
28
/// Progress notifications passed to the download callback.
#[derive(Debug, Copy, Clone)]
pub enum Event<'a> {
    /// An existing partial file is being resumed. Emitted before the bytes
    /// already on disk are replayed as `DownloadDataReceived` events.
    ResumingPartialDownload,
    /// Received the Content-Length of the to-be downloaded data.
    DownloadContentLengthReceived(u64),
    /// Received some data.
    DownloadDataReceived(&'a [u8]),
}
37
38fn download_with_backend(
39    backend: Backend,
40    url: &Url,
41    resume_from: u64,
42    callback: &dyn Fn(Event<'_>) -> Result<()>,
43) -> Result<()> {
44    match backend {
45        Backend::Curl => curl::download(url, resume_from, callback),
46        Backend::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls),
47    }
48}
49
50pub fn download_to_path_with_backend(
51    backend: Backend,
52    url: &Url,
53    path: &Path,
54    resume_from_partial: bool,
55    callback: Option<&dyn Fn(Event<'_>) -> Result<()>>,
56) -> Result<()> {
57    use std::cell::RefCell;
58    use std::fs::remove_file;
59    use std::fs::OpenOptions;
60    use std::io::{Read, Seek, SeekFrom, Write};
61
62    || -> Result<()> {
63        let (file, resume_from) = if resume_from_partial {
64            let possible_partial = OpenOptions::new().read(true).open(&path);
65
66            let downloaded_so_far = if let Ok(mut partial) = possible_partial {
67                if let Some(cb) = callback {
68                    cb(Event::ResumingPartialDownload)?;
69
70                    let mut buf = vec![0; 32768];
71                    let mut downloaded_so_far = 0;
72                    loop {
73                        let n = partial.read(&mut buf)?;
74                        downloaded_so_far += n as u64;
75                        if n == 0 {
76                            break;
77                        }
78                        cb(Event::DownloadDataReceived(&buf[..n]))?;
79                    }
80
81                    downloaded_so_far
82                } else {
83                    let file_info = partial.metadata()?;
84                    file_info.len()
85                }
86            } else {
87                0
88            };
89
90            let mut possible_partial = OpenOptions::new()
91                .write(true)
92                .create(true)
93                .open(&path)
94                .context("error opening file for download")?;
95
96            possible_partial.seek(SeekFrom::End(0))?;
97
98            (possible_partial, downloaded_so_far)
99        } else {
100            (
101                OpenOptions::new()
102                    .write(true)
103                    .create(true)
104                    .open(&path)
105                    .context("error creating file for download")?,
106                0,
107            )
108        };
109
110        let file = RefCell::new(file);
111
112        download_with_backend(backend, url, resume_from, &|event| {
113            if let Event::DownloadDataReceived(data) = event {
114                file.borrow_mut()
115                    .write_all(data)
116                    .context("unable to write download to disk")?;
117            }
118            match callback {
119                Some(cb) => cb(event),
120                None => Ok(()),
121            }
122        })?;
123
124        file.borrow_mut()
125            .sync_data()
126            .context("unable to sync download to disk")?;
127
128        Ok(())
129    }()
130    .map_err(|e| {
131        // TODO: We currently clear up the cached download on any error, should we restrict it to a subset?
132        if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") {
133            file_err.context(e)
134        } else {
135            e
136        }
137    })
138}
139
/// Download via libcurl; encrypt with the native (or OpenSSL) TLS
/// stack via libcurl
#[cfg(feature = "curl-backend")]
pub mod curl {
    use std::cell::RefCell;
    use std::str;
    use std::time::Duration;

    use anyhow::{Context, Result};
    use curl::easy::Easy;
    use url::Url;

    use super::Event;
    use crate::errors::*;

    /// Extracts the value of a `Content-Length` header line.
    ///
    /// The header name is matched case-insensitively and any whitespace
    /// around the value is ignored. Returns `None` for every other header
    /// line, for non-UTF-8 input and for values that do not parse as `u64`.
    fn parse_content_length(header: &[u8]) -> Option<u64> {
        let line = str::from_utf8(header).ok()?;
        let mut parts = line.splitn(2, ':');
        let name = parts.next()?;
        let value = parts.next()?;
        if name.eq_ignore_ascii_case("content-length") {
            value.trim().parse().ok()
        } else {
            None
        }
    }

    /// Downloads `url` with libcurl, reporting progress through `callback`.
    ///
    /// If `resume_from` is non-zero the transfer is resumed at that byte
    /// offset. A received `Content-Length` is reported as
    /// [`Event::DownloadContentLengthReceived`] with `resume_from` added, and
    /// each chunk of body data as [`Event::DownloadDataReceived`].
    ///
    /// # Errors
    ///
    /// Returns the callback's error if the callback aborted the transfer,
    /// [`DownloadError::FileNotFound`] (as context) when curl reports that a
    /// file could not be read, [`DownloadError::HttpStatus`] when the response
    /// code is neither `0` nor in `200..=299`, and a generic
    /// "error during download" context for any other curl failure.
    pub fn download(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
    ) -> Result<()> {
        // Fetch either a cached libcurl handle (which will preserve open
        // connections) or create a new one if it isn't listed.
        //
        // Once we've acquired it, reset the lifetime from 'static to our local
        // scope.
        thread_local!(static EASY: RefCell<Easy> = RefCell::new(Easy::new()));
        EASY.with(|handle| {
            let mut handle = handle.borrow_mut();

            handle.url(url.as_ref())?;
            handle.follow_location(true)?;
            handle.useragent(super::USER_AGENT)?;

            if resume_from > 0 {
                handle.resume_from(resume_from)?;
            } else {
                // an error here indicates that the range header isn't supported by underlying curl,
                // so there's nothing to "clear" - safe to ignore this error.
                let _ = handle.resume_from(0);
            }

            // Take at most 30s to connect
            handle.connect_timeout(Duration::new(30, 0))?;

            {
                // Holds the first error returned by `callback`, so it can be
                // surfaced instead of curl's generic "aborted" error.
                let cberr = RefCell::new(None);
                let mut transfer = handle.transfer();

                // Data callback for libcurl which is called with data that's
                // downloaded. The data is handed straight to `callback`;
                // reporting 0 bytes consumed makes libcurl abort the transfer.
                transfer.write_function(|data| {
                    match callback(Event::DownloadDataReceived(data)) {
                        Ok(()) => Ok(data.len()),
                        Err(e) => {
                            *cberr.borrow_mut() = Some(e);
                            Ok(0)
                        }
                    }
                })?;

                // Listen for headers and parse out a `Content-Length` (case-insensitive) if it
                // comes so we know how much we're downloading.
                transfer.header_function(|header| {
                    // A total that would overflow `u64` is not reported at all.
                    let total = parse_content_length(header)
                        .and_then(|len| len.checked_add(resume_from));
                    if let Some(total) = total {
                        if let Err(e) = callback(Event::DownloadContentLengthReceived(total)) {
                            *cberr.borrow_mut() = Some(e);
                            return false;
                        }
                    }
                    true
                })?;

                if let Err(e) = transfer.perform() {
                    // If the original error was generated by one of our
                    // callbacks, return it.
                    let callback_error = cberr.borrow_mut().take();
                    if let Some(callback_error) = callback_error {
                        return Err(callback_error);
                    }

                    // Otherwise, return the error from curl
                    return if e.is_file_couldnt_read_file() {
                        Err(e).context(DownloadError::FileNotFound)
                    } else {
                        Err(e).context("error during download")
                    };
                }
            }

            // If we didn't get a 20x or 0 ("OK" for files) then return an error
            let code = handle.response_code()?;
            match code {
                0 | 200..=299 => Ok(()),
                _ => Err(DownloadError::HttpStatus(code).into()),
            }
        })
    }
}
254
/// Download via reqwest, using either rustls or the default TLS backend.
#[cfg(feature = "reqwest-backend")]
pub mod reqwest_be {
    use std::io;
    use std::time::Duration;

    use anyhow::{anyhow, Context, Result};
    use lazy_static::lazy_static;
    use reqwest::blocking::{Client, ClientBuilder, Response};
    use reqwest::{header, Proxy};
    use url::Url;

    use super::Event;
    use super::TlsBackend;
    use crate::errors::*;

    /// Downloads `url` with reqwest, reporting progress through `callback`.
    ///
    /// `file:` URLs are read directly from disk without involving reqwest.
    /// If `resume_from` is non-zero, the download resumes at that byte offset.
    /// A valid `Content-Length` response header is reported as
    /// [`Event::DownloadContentLengthReceived`] with `resume_from` added; a
    /// missing or malformed header is ignored. Body data is delivered as
    /// [`Event::DownloadDataReceived`].
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be made, if the response status
    /// is not a success ([`DownloadError::HttpStatus`]), if reading the body
    /// fails, or if `callback` returns an error.
    pub fn download(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
        tls: TlsBackend,
    ) -> Result<()> {
        // Short-circuit reqwest for the "file:" URL scheme
        if download_from_file_url(url, resume_from, callback)? {
            return Ok(());
        }

        let mut res = request(url, resume_from, tls).context("failed to make network request")?;

        if !res.status().is_success() {
            let code: u16 = res.status().into();
            return Err(anyhow!(DownloadError::HttpStatus(u32::from(code))));
        }

        let buffer_size = 0x10000;
        let mut buffer = vec![0u8; buffer_size];

        // The header value comes from the server, so it is parsed fallibly:
        // a non-ASCII or non-numeric value (or one that would overflow once
        // the resume offset is added) is skipped rather than causing a panic.
        let total_len = res
            .headers()
            .get(header::CONTENT_LENGTH)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.trim().parse::<u64>().ok())
            .and_then(|len| len.checked_add(resume_from));
        if let Some(len) = total_len {
            callback(Event::DownloadContentLengthReceived(len))?;
        }

        loop {
            let bytes_read = io::Read::read(&mut res, &mut buffer)?;

            if bytes_read != 0 {
                callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?;
            } else {
                return Ok(());
            }
        }
    }

    /// Builds the client configuration shared by all TLS backends: gzip
    /// disabled, the rustup user agent, environment-derived proxies and a
    /// 30 second timeout.
    fn client_generic() -> ClientBuilder {
        Client::builder()
            .gzip(false)
            .user_agent(super::USER_AGENT)
            .proxy(Proxy::custom(env_proxy))
            .timeout(Duration::from_secs(30))
    }

    #[cfg(feature = "reqwest-rustls-tls")]
    lazy_static! {
        static ref CLIENT_RUSTLS_TLS: Client = {
            // woah, an unwrap?!
            // It's OK. This is the same as what is happening in curl.
            //
            // The curl::Easy::new() internally assert!s that the initialized
            // Easy is not null. Inside reqwest, the errors here would be from
            // the TLS library returning a null pointer as well.
            client_generic().use_rustls_tls().build().unwrap()
        };
    }

    #[cfg(feature = "reqwest-default-tls")]
    lazy_static! {
        static ref CLIENT_DEFAULT_TLS: Client = {
            // woah, an unwrap?!
            // It's OK. This is the same as what is happening in curl.
            //
            // The curl::Easy::new() internally assert!s that the initialized
            // Easy is not null. Inside reqwest, the errors here would be from
            // the TLS library returning a null pointer as well.
            client_generic().build().unwrap()
        };
    }

    /// Looks up the proxy configured in the environment for `url`, if any.
    fn env_proxy(url: &Url) -> Option<Url> {
        env_proxy::for_url(url).to_url()
    }

    /// Sends a GET request for `url` using the client matching `backend`.
    ///
    /// A `Range: bytes=<resume_from>-` header is added when `resume_from` is
    /// non-zero. Returns [`DownloadError::BackendUnavailable`] when the
    /// requested TLS backend was not compiled in.
    fn request(
        url: &Url,
        resume_from: u64,
        backend: TlsBackend,
    ) -> Result<Response, DownloadError> {
        let client: &Client = match backend {
            #[cfg(feature = "reqwest-rustls-tls")]
            TlsBackend::Rustls => &CLIENT_RUSTLS_TLS,
            #[cfg(not(feature = "reqwest-rustls-tls"))]
            TlsBackend::Rustls => {
                return Err(DownloadError::BackendUnavailable("reqwest rustls"));
            }
            #[cfg(feature = "reqwest-default-tls")]
            TlsBackend::Default => &CLIENT_DEFAULT_TLS,
            #[cfg(not(feature = "reqwest-default-tls"))]
            TlsBackend::Default => {
                return Err(DownloadError::BackendUnavailable("reqwest default TLS"));
            }
        };
        let mut req = client.get(url.as_str());

        if resume_from != 0 {
            req = req.header(header::RANGE, format!("bytes={}-", resume_from));
        }

        Ok(req.send()?)
    }

    /// Serves `file:` URLs straight from the local filesystem.
    ///
    /// Returns `Ok(false)` without doing anything when `url` does not use the
    /// `file` scheme. Otherwise the file is read starting at `resume_from`,
    /// its contents are passed to `callback` as
    /// [`Event::DownloadDataReceived`] events, and `Ok(true)` is returned.
    ///
    /// # Errors
    ///
    /// Returns [`DownloadError::FileNotFound`] when the path is not an
    /// existing file, and an error when the URL cannot be converted to a
    /// path, on I/O failure, or when `callback` fails.
    fn download_from_file_url(
        url: &Url,
        resume_from: u64,
        callback: &dyn Fn(Event<'_>) -> Result<()>,
    ) -> Result<bool> {
        use std::fs;

        // The file scheme is mostly for use by tests to mock the dist server
        if url.scheme() != "file" {
            return Ok(false);
        }

        let src = url
            .to_file_path()
            .map_err(|_| DownloadError::Message(format!("bogus file url: '{}'", url)))?;
        if !src.is_file() {
            // Because some of rustup's logic depends on checking
            // the error when a downloaded file doesn't exist, make
            // the file case return the same error value as the
            // network case.
            return Err(anyhow!(DownloadError::FileNotFound));
        }

        let mut f = fs::File::open(src).context("unable to open downloaded file")?;
        io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?;

        let mut buffer = vec![0u8; 0x10000];
        loop {
            let bytes_read = io::Read::read(&mut f, &mut buffer)?;
            if bytes_read == 0 {
                break;
            }
            callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?;
        }

        Ok(true)
    }
}
420
421#[cfg(not(feature = "curl-backend"))]
422pub mod curl {
423
424    use anyhow::{anyhow, Result};
425
426    use super::Event;
427    use crate::errors::*;
428    use url::Url;
429
430    pub fn download(
431        _url: &Url,
432        _resume_from: u64,
433        _callback: &dyn Fn(Event<'_>) -> Result<()>,
434    ) -> Result<()> {
435        Err(anyhow!(DownloadError::BackendUnavailable("curl")))
436    }
437}
438
439#[cfg(not(feature = "reqwest-backend"))]
440pub mod reqwest_be {
441
442    use anyhow::{anyhow, Result};
443
444    use super::Event;
445    use super::TlsBackend;
446    use crate::errors::*;
447    use url::Url;
448
449    pub fn download(
450        _url: &Url,
451        _resume_from: u64,
452        _callback: &dyn Fn(Event<'_>) -> Result<()>,
453        _tls: TlsBackend,
454    ) -> Result<()> {
455        Err(anyhow!(DownloadError::BackendUnavailable("reqwest")))
456    }
457}