Skip to main content

shell_download/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod drivers;
4mod tempfile;
5mod url_parser;
6mod util;
7
8use std::io;
9use std::path::Path;
10use std::sync::{
11    Arc,
12    atomic::{AtomicBool, Ordering},
13};
14use std::thread::JoinHandle;
15
/// Download backends this crate can drive.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Downloader {
    /// The `curl` command-line tool.
    Curl,
    /// The `wget` command-line tool.
    Wget,
    /// PowerShell (`pwsh` or `powershell`).
    PowerShell,
    /// Python's `urllib`.
    Python3,
    /// HTTP/1.1 spoken through `openssl s_client` or a TCP socket (best-effort).
    OpenSsl,
}
30
/// Policy for forwarding the child's stdout/stderr to the parent process.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quiet {
    /// Never quiet: child stdout/stderr is always forwarded to the parent.
    Never,
    /// Always quiet: child stdout/stderr is never forwarded.
    Always,
    /// Quiet only on success: output is forwarded when the command fails.
    OnSuccess,
}
41
/// Content encoding of a response body, when it is known.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ContentEncoding {
    /// The body is gzip-compressed.
    Gzip,
}
48
49#[derive(Debug, Clone)]
50/// Builder for a single download request.
51pub struct RequestBuilder {
52    pub(crate) url: String,
53    pub(crate) headers: Vec<(String, String)>,
54    pub(crate) preferred: Vec<Downloader>,
55    pub(crate) follow_redirects: bool,
56    pub(crate) quiet: Quiet,
57}
58
59#[derive(Debug, Clone)]
60/// Low-level download result prior to finalizing the output file.
61pub struct DownloadResult {
62    /// HTTP status code (best-effort).
63    pub status_code: u16,
64    /// Response content encoding, if known.
65    pub content_encoding: Option<ContentEncoding>,
66}
67
68impl RequestBuilder {
69    /// Create a new request builder.
70    pub fn new(url: impl Into<String>) -> Self {
71        Self {
72            url: url.into(),
73            headers: Vec::new(),
74            preferred: Vec::new(),
75            follow_redirects: true,
76            quiet: Quiet::Always,
77        }
78    }
79
80    /// Add an HTTP header.
81    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
82        self.headers.push((key.into(), value.into()));
83        self
84    }
85
86    /// Prefer a specific downloader backend.
87    pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
88        self.preferred.push(preferred);
89        self
90    }
91
92    /// Enable or disable HTTP redirect following.
93    pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
94        self.follow_redirects = follow_redirects;
95        self
96    }
97
98    /// Control forwarding of child output.
99    pub fn quiet(mut self, quiet: Quiet) -> Self {
100        self.quiet = quiet;
101        self
102    }
103
104    /// Fetch the response body as a String, blocking until the download is
105    /// complete.
106    pub fn fetch_string(self) -> Result<String, ResponseError> {
107        String::from_utf8(self.fetch_bytes()?)
108            .map_err(|e| ResponseError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))
109    }
110
111    /// Fetch the response body as a String, blocking until the download is
112    /// complete.
113    pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
114        let tmp = crate::tempfile::create_tmp_file_in_path(
115            "in-memory",
116            None,
117            &std::env::temp_dir(),
118            "shell-download-in-memory",
119        )
120        .map_err(ResponseError::Io)?;
121        let handle = self.start(&tmp).map_err(ResponseError::Start)?;
122        let _res = handle.join()?;
123        std::fs::read(&tmp).map_err(ResponseError::Io)
124    }
125
126    /// Start the download in a background thread.
127    pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
128        let target_path = target_path.as_ref().to_path_buf();
129
130        if let Some(parent) = target_path.parent() {
131            if !parent.as_os_str().is_empty() {
132                std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
133            }
134        }
135
136        let _ = std::fs::remove_file(&target_path);
137
138        // URL preflight: fail early with a message useful to callers.
139        let url = url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
140
141        let parent = target_path.parent().unwrap_or_else(|| Path::new("."));
142        let hint = target_path
143            .file_name()
144            .and_then(|s| s.to_str())
145            .unwrap_or("download");
146        let tmp_path =
147            crate::tempfile::create_tmp_file_in_path("download", Some(&url), parent, hint)
148                .map_err(StartError::IoError)?;
149
150        let cancel = Arc::new(AtomicBool::new(false));
151        let mut saw_non_not_found: Option<io::Error> = None;
152        let mut saw_any_not_found = false;
153
154        for d in candidate_downloaders(&self.preferred) {
155            match d
156                .driver()
157                .start(self.clone(), tmp_path.as_ref(), Arc::clone(&cancel))
158            {
159                Ok(join) => {
160                    return Ok(RequestHandle {
161                        cancel,
162                        join: Some(join),
163                        target_path,
164                        tmp_path: Some(tmp_path),
165                    });
166                }
167                Err(StartError::NoDriverFound) => {
168                    saw_any_not_found = true;
169                    continue;
170                }
171                Err(StartError::IoError(e)) => {
172                    if saw_non_not_found.is_none() {
173                        saw_non_not_found = Some(e);
174                    }
175                    continue;
176                }
177                Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
178            }
179        }
180
181        if let Some(e) = saw_non_not_found {
182            return Err(StartError::IoError(e));
183        }
184        if saw_any_not_found {
185            return Err(StartError::NoDriverFound);
186        }
187        Err(StartError::NoDriverFound)
188    }
189}
190
191impl Downloader {
192    pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
193        static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
194        static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
195        static POWERSHELL: drivers::powershell::PowerShellDriver =
196            drivers::powershell::PowerShellDriver;
197        static PYTHON3: drivers::python3::Python3Driver = drivers::python3::Python3Driver;
198        static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
199
200        match self {
201            Downloader::Curl => &CURL,
202            Downloader::Wget => &WGET,
203            Downloader::PowerShell => &POWERSHELL,
204            Downloader::Python3 => &PYTHON3,
205            Downloader::OpenSsl => &OPENSSL,
206        }
207    }
208}
209
210#[derive(Debug)]
211/// Handle for a running download.
212pub struct RequestHandle {
213    cancel: Arc<AtomicBool>,
214    join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
215    target_path: std::path::PathBuf,
216    tmp_path: Option<crate::tempfile::TmpFile>,
217}
218
219impl RequestHandle {
220    /// Request cancellation (best-effort).
221    pub fn cancel(&self) {
222        self.cancel.store(true, Ordering::SeqCst);
223    }
224
225    /// Wait for completion and finalize the output file.
226    pub fn join(mut self) -> Result<Response, ResponseError> {
227        let res = match self.join.take().expect("join called once").join() {
228            Ok(r) => r,
229            Err(_) => Err(ResponseError::ThreadPanicked),
230        }?;
231
232        let tmp_path = self.tmp_path.take().expect("tmp_path present");
233        util::finalize_download(tmp_path, &self.target_path, res.content_encoding)?;
234        Ok(Response {
235            status_code: res.status_code,
236        })
237    }
238}
239
240impl Drop for RequestHandle {
241    fn drop(&mut self) {
242        if self.join.is_some() {
243            self.cancel.store(true, Ordering::SeqCst);
244            // `tmp_path` will clean itself up via `Drop`.
245        }
246    }
247}
248
/// Metadata describing a download that ran to completion.
#[derive(Clone, Debug)]
pub struct Response {
    /// HTTP status code reported for the response (best-effort).
    pub status_code: u16,
}
255
/// Errors that can occur while starting a download.
#[derive(Debug)]
pub enum StartError {
    /// No usable backend executable was found.
    NoDriverFound,
    /// A local I/O error occurred.
    IoError(io::Error),
    /// URL validation failed.
    Url(String),
}

impl From<io::Error> for StartError {
    fn from(value: io::Error) -> Self {
        Self::IoError(value)
    }
}

impl std::fmt::Display for StartError {
    /// Write a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoDriverFound => f.write_str("no usable download backend was found"),
            Self::IoError(e) => write!(f, "I/O error: {}", e),
            Self::Url(msg) => write!(f, "invalid URL: {}", msg),
        }
    }
}

impl std::error::Error for StartError {
    /// Expose the wrapped I/O error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(e) => Some(e),
            Self::NoDriverFound | Self::Url(_) => None,
        }
    }
}

/// Errors that can occur while running a request.
#[derive(Debug)]
pub enum ResponseError {
    /// A local I/O error occurred.
    Io(io::Error),
    /// The URL could not be parsed.
    InvalidUrl,
    /// The URL scheme is unsupported.
    UnsupportedScheme,
    /// The request was cancelled.
    Cancelled,
    /// The worker thread panicked.
    ThreadPanicked,
    /// The backend command failed.
    CommandFailed {
        /// Backend program label.
        program: &'static str,
        /// Process exit code, if available.
        exit_code: Option<i32>,
        /// Captured stderr (best-effort).
        stderr: String,
    },
    /// The backend returned a non-numeric status code.
    BadStatusCode(String),
    /// Gzip decoding failed.
    GzipFailed {
        /// Process exit code, if available.
        exit_code: Option<i32>,
        /// Captured stderr (best-effort).
        stderr: String,
    },
    /// Download start failed.
    Start(StartError),
}

impl From<io::Error> for ResponseError {
    fn from(value: io::Error) -> Self {
        Self::Io(value)
    }
}

impl std::fmt::Display for ResponseError {
    /// Write a short, human-readable description of the error.
    ///
    /// Process-failure variants append the exit code when one is available
    /// and the captured stderr (trimmed) when it is non-empty.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Append the optional exit code and trimmed stderr of a failed process.
        fn exit_details(
            f: &mut std::fmt::Formatter<'_>,
            exit_code: Option<i32>,
            stderr: &str,
        ) -> std::fmt::Result {
            if let Some(code) = exit_code {
                write!(f, " with exit code {}", code)?;
            }
            let stderr = stderr.trim();
            if !stderr.is_empty() {
                write!(f, ": {}", stderr)?;
            }
            Ok(())
        }

        match self {
            Self::Io(e) => write!(f, "I/O error: {}", e),
            Self::InvalidUrl => f.write_str("the URL could not be parsed"),
            Self::UnsupportedScheme => f.write_str("the URL scheme is unsupported"),
            Self::Cancelled => f.write_str("the request was cancelled"),
            Self::ThreadPanicked => f.write_str("the download worker thread panicked"),
            Self::CommandFailed {
                program,
                exit_code,
                stderr,
            } => {
                write!(f, "{} failed", program)?;
                exit_details(f, *exit_code, stderr)
            }
            Self::BadStatusCode(raw) => {
                write!(f, "backend returned a non-numeric status code: {:?}", raw)
            }
            Self::GzipFailed { exit_code, stderr } => {
                f.write_str("gzip decoding failed")?;
                exit_details(f, *exit_code, stderr)
            }
            Self::Start(e) => write!(f, "failed to start the download: {}", e),
        }
    }
}

impl std::error::Error for ResponseError {
    /// Expose the wrapped I/O or start error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::Start(e) => Some(e),
            Self::InvalidUrl
            | Self::UnsupportedScheme
            | Self::Cancelled
            | Self::ThreadPanicked
            | Self::CommandFailed { .. }
            | Self::BadStatusCode(_)
            | Self::GzipFailed { .. } => None,
        }
    }
}
313
314/// Choose downloaders in priority order.
315fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
316    if !preferred.is_empty() {
317        return preferred.to_vec();
318    }
319    vec![
320        Downloader::Curl,
321        Downloader::Wget,
322        Downloader::PowerShell,
323        Downloader::Python3,
324        Downloader::OpenSsl,
325    ]
326}