// shell_download/lib.rs

1mod drivers;
2mod url_parser;
3mod util;
4
5use std::io;
6use std::path::Path;
7use std::sync::{
8    Arc,
9    atomic::{AtomicBool, Ordering},
10};
11use std::thread::JoinHandle;
12
/// Backend that can be asked to perform a download.
///
/// Each variant is mapped to exactly one driver implementation by
/// `Downloader::driver`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Downloader {
    /// Backed by the driver in `drivers::curl`.
    Curl,
    /// Backed by the driver in `drivers::wget`.
    Wget,
    /// Backed by the driver in `drivers::powershell`.
    PowerShell,
    /// Backed by the driver in `drivers::openssl`.
    OpenSsl,
}
20
/// Policy for relaying the child process's stdout/stderr to the parent
/// process.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quiet {
    /// Not quiet at all: child stdout/stderr is always forwarded to the
    /// parent process.
    Never,
    /// Fully quiet: child stdout/stderr is never forwarded.
    Always,
    /// Quiet only when the command succeeds: output is forwarded if the
    /// command fails.
    OnSuccess,
}
30
31#[derive(Debug, Clone)]
32pub struct RequestBuilder {
33    pub(crate) url: String,
34    pub(crate) headers: Vec<(String, String)>,
35    pub(crate) preferred: Vec<Downloader>,
36    pub(crate) follow_redirects: bool,
37    pub(crate) quiet: Quiet,
38}
39
/// Value produced by a driver's worker thread when a download finishes
/// without error.
#[derive(Clone, Debug)]
pub struct DownloadResult {
    /// Status code reported for the response; it is copied into `Response`
    /// by `RequestHandle::join`.
    pub status_code: u16,
    /// Gzip flag passed on to `util::finalize_download` by
    /// `RequestHandle::join` (presumably set when the body arrived with a
    /// gzip content encoding — confirm against the drivers).
    pub content_encoding_gzip: bool,
}
45
46impl RequestBuilder {
47    pub fn new(url: impl Into<String>) -> Self {
48        Self {
49            url: url.into(),
50            headers: Vec::new(),
51            preferred: Vec::new(),
52            follow_redirects: true,
53            quiet: Quiet::Always,
54        }
55    }
56
57    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
58        self.headers.push((key.into(), value.into()));
59        self
60    }
61
62    pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
63        self.preferred.push(preferred);
64        self
65    }
66
67    pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
68        self.follow_redirects = follow_redirects;
69        self
70    }
71
72    pub fn quiet(mut self, quiet: Quiet) -> Self {
73        self.quiet = quiet;
74        self
75    }
76
77    /// Fetch the response body as a String, blocking until the download is
78    /// complete.
79    #[cfg(feature = "in-memory")]
80    pub fn fetch_string(self) -> Result<String, ResponseError> {
81        let tmp_file = tempfile::NamedTempFile::new()?;
82        let handle = self.start(tmp_file.path()).map_err(ResponseError::Start)?;
83        let _res = handle.join()?;
84        std::fs::read_to_string(tmp_file.path()).map_err(ResponseError::Io)
85    }
86
87    /// Fetch the response body as a String, blocking until the download is
88    /// complete.
89    #[cfg(feature = "in-memory")]
90    pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
91        let tmp_file = tempfile::NamedTempFile::new()?;
92        let handle = self.start(tmp_file.path()).map_err(ResponseError::Start)?;
93        let _res = handle.join()?;
94        std::fs::read(tmp_file.path()).map_err(ResponseError::Io)
95    }
96
97    pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
98        let target_path = target_path.as_ref().to_path_buf();
99
100        if let Some(parent) = target_path.parent() {
101            if !parent.as_os_str().is_empty() {
102                std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
103            }
104        }
105
106        let _ = std::fs::remove_file(&target_path);
107
108        // URL preflight: fail early with a message useful to callers.
109        url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
110
111        let tmp_path = util::tmp_path_for_target(&target_path);
112        let _ = std::fs::remove_file(&tmp_path);
113
114        let cancel = Arc::new(AtomicBool::new(false));
115        let mut saw_non_not_found: Option<io::Error> = None;
116        let mut saw_any_not_found = false;
117
118        for d in candidate_downloaders(&self.preferred) {
119            match d
120                .driver()
121                .start(self.clone(), tmp_path.clone(), Arc::clone(&cancel))
122            {
123                Ok(join) => {
124                    return Ok(RequestHandle {
125                        cancel,
126                        join: Some(join),
127                        target_path,
128                        tmp_path,
129                    });
130                }
131                Err(StartError::NoDriverFound) => {
132                    saw_any_not_found = true;
133                    continue;
134                }
135                Err(StartError::IoError(e)) => {
136                    if saw_non_not_found.is_none() {
137                        saw_non_not_found = Some(e);
138                    }
139                    continue;
140                }
141                Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
142            }
143        }
144
145        if let Some(e) = saw_non_not_found {
146            return Err(StartError::IoError(e));
147        }
148        if saw_any_not_found {
149            return Err(StartError::NoDriverFound);
150        }
151        Err(StartError::NoDriverFound)
152    }
153}
154
155impl Downloader {
156    pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
157        static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
158        static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
159        static POWERSHELL: drivers::powershell::PowerShellDriver =
160            drivers::powershell::PowerShellDriver;
161        static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
162
163        match self {
164            Downloader::Curl => &CURL,
165            Downloader::Wget => &WGET,
166            Downloader::PowerShell => &POWERSHELL,
167            Downloader::OpenSsl => &OPENSSL,
168        }
169    }
170}
171
172#[derive(Debug)]
173pub struct RequestHandle {
174    cancel: Arc<AtomicBool>,
175    join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
176    target_path: std::path::PathBuf,
177    tmp_path: std::path::PathBuf,
178}
179
180impl RequestHandle {
181    pub fn cancel(&self) {
182        self.cancel.store(true, Ordering::SeqCst);
183    }
184
185    pub fn join(mut self) -> Result<Response, ResponseError> {
186        let res = match self.join.take().expect("join called once").join() {
187            Ok(r) => r,
188            Err(_) => Err(ResponseError::ThreadPanicked),
189        }?;
190
191        util::finalize_download(&self.tmp_path, &self.target_path, res.content_encoding_gzip)?;
192        Ok(Response {
193            status_code: res.status_code,
194        })
195    }
196}
197
198impl Drop for RequestHandle {
199    fn drop(&mut self) {
200        if self.join.is_some() {
201            self.cancel.store(true, Ordering::SeqCst);
202            let _ = std::fs::remove_file(&self.tmp_path);
203        }
204    }
205}
206
/// Result returned by `RequestHandle::join` for a completed download.
#[derive(Clone, Debug)]
pub struct Response {
    /// Status code copied from the worker thread's `DownloadResult`.
    pub status_code: u16,
}
211
/// Reasons a download could not be started.
#[derive(Debug)]
pub enum StartError {
    /// No candidate downloader could be started and none of them reported an
    /// I/O error. Drivers also return this to signal that they are
    /// unavailable.
    NoDriverFound,
    /// An I/O error, either from creating the target's parent directory or
    /// reported by a driver while starting.
    IoError(io::Error),
    /// The URL was rejected; carries the parser's message.
    Url(String),
}
218
219impl From<io::Error> for StartError {
220    fn from(value: io::Error) -> Self {
221        Self::IoError(value)
222    }
223}
224
225#[derive(Debug)]
226pub enum ResponseError {
227    Io(io::Error),
228    InvalidUrl,
229    UnsupportedScheme,
230    Cancelled,
231    ThreadPanicked,
232    CommandFailed {
233        program: &'static str,
234        exit_code: Option<i32>,
235        stderr: String,
236    },
237    BadStatusCode(String),
238    GzipFailed {
239        exit_code: Option<i32>,
240        stderr: String,
241    },
242    Start(StartError),
243}
244
245impl From<io::Error> for ResponseError {
246    fn from(value: io::Error) -> Self {
247        Self::Io(value)
248    }
249}
250
251fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
252    if !preferred.is_empty() {
253        return preferred.to_vec();
254    }
255    vec![
256        Downloader::Curl,
257        Downloader::Wget,
258        Downloader::PowerShell,
259        Downloader::OpenSsl,
260    ]
261}