Skip to main content

shell_download/
lib.rs

1mod drivers;
2mod url_parser;
3mod util;
4
5use std::io;
6use std::path::Path;
7use std::sync::{
8    Arc,
9    atomic::{AtomicBool, Ordering},
10};
11use std::thread::JoinHandle;
12
/// Backend that can perform a download.
///
/// Each variant is mapped to its driver implementation by `Downloader::driver`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Downloader {
    /// The curl backend.
    Curl,
    /// The wget backend.
    Wget,
    /// The PowerShell backend.
    PowerShell,
    /// The OpenSSL backend.
    OpenSsl,
}
20
/// Policy for forwarding the downloader child process's stdout/stderr.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quiet {
    /// Forward child stdout/stderr to this process unconditionally.
    Never,
    /// Suppress child stdout/stderr unconditionally.
    Always,
    /// Suppress child output when the command succeeds; forward it when it fails.
    OnSuccess,
}
30
31#[derive(Debug, Clone)]
32pub struct RequestBuilder {
33    pub(crate) url: String,
34    pub(crate) headers: Vec<(String, String)>,
35    pub(crate) preferred: Vec<Downloader>,
36    pub(crate) follow_redirects: bool,
37    pub(crate) quiet: Quiet,
38}
39
/// Outcome a driver thread reports once its transfer has finished.
#[derive(Clone, Debug)]
pub struct DownloadResult {
    /// Status code of the response.
    pub status_code: u16,
    /// Whether the response was gzip content-encoded; this flag is forwarded
    /// to `util::finalize_download` when the handle is joined.
    pub content_encoding_gzip: bool,
}
45
46impl RequestBuilder {
47    pub fn new(url: impl Into<String>) -> Self {
48        Self {
49            url: url.into(),
50            headers: Vec::new(),
51            preferred: Vec::new(),
52            follow_redirects: true,
53            quiet: Quiet::Always,
54        }
55    }
56
57    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
58        self.headers.push((key.into(), value.into()));
59        self
60    }
61
62    pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
63        self.preferred.push(preferred);
64        self
65    }
66
67    pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
68        self.follow_redirects = follow_redirects;
69        self
70    }
71
72    pub fn quiet(mut self, quiet: Quiet) -> Self {
73        self.quiet = quiet;
74        self
75    }
76
77    /// Fetch the response body as a String, blocking until the download is
78    /// complete.
79    #[cfg(feature = "in-memory")]
80    pub fn fetch_string(self) -> Result<String, ResponseError> {
81        let tmp_file = tempfile::NamedTempFile::new()?;
82        let handle = self
83            .start(tmp_file.path())
84            .map_err(ResponseError::Start)?;
85        let _res = handle.join()?;
86        std::fs::read_to_string(tmp_file.path()).map_err(ResponseError::Io)
87    }
88
89    /// Fetch the response body as a String, blocking until the download is
90    /// complete.
91    #[cfg(feature = "in-memory")]
92    pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
93        let tmp_file = tempfile::NamedTempFile::new()?;
94        let handle = self
95            .start(tmp_file.path())
96            .map_err(ResponseError::Start)?;
97        let _res = handle.join()?;
98        std::fs::read(tmp_file.path()).map_err(ResponseError::Io)
99    }
100
101    pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
102        let target_path = target_path.as_ref().to_path_buf();
103
104        if let Some(parent) = target_path.parent() {
105            if !parent.as_os_str().is_empty() {
106                std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
107            }
108        }
109
110        let _ = std::fs::remove_file(&target_path);
111
112        // URL preflight: fail early with a message useful to callers.
113        url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
114
115        let tmp_path = util::tmp_path_for_target(&target_path);
116        let _ = std::fs::remove_file(&tmp_path);
117
118        let cancel = Arc::new(AtomicBool::new(false));
119        let mut saw_non_not_found: Option<io::Error> = None;
120        let mut saw_any_not_found = false;
121
122        for d in candidate_downloaders(&self.preferred) {
123            match d
124                .driver()
125                .start(self.clone(), tmp_path.clone(), Arc::clone(&cancel))
126            {
127                Ok(join) => {
128                    return Ok(RequestHandle {
129                        cancel,
130                        join: Some(join),
131                        target_path,
132                        tmp_path,
133                    });
134                }
135                Err(StartError::NoDriverFound) => {
136                    saw_any_not_found = true;
137                    continue;
138                }
139                Err(StartError::IoError(e)) => {
140                    if saw_non_not_found.is_none() {
141                        saw_non_not_found = Some(e);
142                    }
143                    continue;
144                }
145                Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
146            }
147        }
148
149        if let Some(e) = saw_non_not_found {
150            return Err(StartError::IoError(e));
151        }
152        if saw_any_not_found {
153            return Err(StartError::NoDriverFound);
154        }
155        Err(StartError::NoDriverFound)
156    }
157}
158
159impl Downloader {
160    pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
161        static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
162        static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
163        static POWERSHELL: drivers::powershell::PowerShellDriver =
164            drivers::powershell::PowerShellDriver;
165        static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
166
167        match self {
168            Downloader::Curl => &CURL,
169            Downloader::Wget => &WGET,
170            Downloader::PowerShell => &POWERSHELL,
171            Downloader::OpenSsl => &OPENSSL,
172        }
173    }
174}
175
176#[derive(Debug)]
177pub struct RequestHandle {
178    cancel: Arc<AtomicBool>,
179    join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
180    target_path: std::path::PathBuf,
181    tmp_path: std::path::PathBuf,
182}
183
184impl RequestHandle {
185    pub fn cancel(&self) {
186        self.cancel.store(true, Ordering::SeqCst);
187    }
188
189    pub fn join(mut self) -> Result<Response, ResponseError> {
190        let res = match self.join.take().expect("join called once").join() {
191            Ok(r) => r,
192            Err(_) => Err(ResponseError::ThreadPanicked),
193        }?;
194
195        util::finalize_download(&self.tmp_path, &self.target_path, res.content_encoding_gzip)?;
196        Ok(Response {
197            status_code: res.status_code,
198        })
199    }
200}
201
202impl Drop for RequestHandle {
203    fn drop(&mut self) {
204        if self.join.is_some() {
205            self.cancel.store(true, Ordering::SeqCst);
206            let _ = std::fs::remove_file(&self.tmp_path);
207        }
208    }
209}
210
/// Result of a successfully joined and finalized download.
#[derive(Clone, Debug)]
pub struct Response {
    /// Status code reported by the driver.
    pub status_code: u16,
}
215
/// Reasons a download could not be started.
#[derive(Debug)]
pub enum StartError {
    /// None of the candidate downloaders is available.
    NoDriverFound,
    /// An I/O error while preparing the target path or launching a downloader.
    IoError(io::Error),
    /// The URL was rejected; the message says why.
    Url(String),
}

impl std::fmt::Display for StartError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StartError::NoDriverFound => f.write_str("no usable downloader was found"),
            StartError::IoError(e) => write!(f, "I/O error while starting download: {e}"),
            StartError::Url(msg) => write!(f, "invalid URL: {msg}"),
        }
    }
}

/// Lets `StartError` flow through `?` into `Box<dyn Error>` and similar
/// error-erasing types; the wrapped I/O error is exposed as the source.
impl std::error::Error for StartError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            StartError::IoError(e) => Some(e),
            StartError::NoDriverFound | StartError::Url(_) => None,
        }
    }
}
222
223impl From<io::Error> for StartError {
224    fn from(value: io::Error) -> Self {
225        Self::IoError(value)
226    }
227}
228
229#[derive(Debug)]
230pub enum ResponseError {
231    Io(io::Error),
232    InvalidUrl,
233    UnsupportedScheme,
234    Cancelled,
235    ThreadPanicked,
236    CommandFailed {
237        program: &'static str,
238        exit_code: Option<i32>,
239        stderr: String,
240    },
241    BadStatusCode(String),
242    GzipFailed {
243        exit_code: Option<i32>,
244        stderr: String,
245    },
246    Start(StartError),
247}
248
249impl From<io::Error> for ResponseError {
250    fn from(value: io::Error) -> Self {
251        Self::Io(value)
252    }
253}
254
255fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
256    if !preferred.is_empty() {
257        return preferred.to_vec();
258    }
259    vec![
260        Downloader::Curl,
261        Downloader::Wget,
262        Downloader::PowerShell,
263        Downloader::OpenSsl,
264    ]
265}