Skip to main content

shell_download/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod drivers;
4mod tempfile;
5mod url_parser;
6mod util;
7
8use std::io;
9use std::path::{Path, PathBuf};
10use std::sync::{
11    Arc,
12    atomic::{AtomicBool, Ordering},
13};
14use std::thread::JoinHandle;
15
/// One of the download backends this crate knows how to drive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Downloader {
    /// The `curl` command-line tool.
    Curl,
    /// The `wget` command-line tool.
    Wget,
    /// PowerShell (`pwsh`/`powershell`).
    PowerShell,
    /// Python's `urllib`.
    Python3,
    /// HTTP/1.1 spoken via `openssl s_client` or a TCP socket (best-effort).
    OpenSsl,
}
30
/// Policy for forwarding the child's stdout/stderr to the parent process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Quiet {
    /// Not quiet at all: child stdout/stderr is always forwarded.
    Never,
    /// Fully quiet: child stdout/stderr is never forwarded.
    Always,
    /// Quiet only when the command succeeds; output is forwarded on failure.
    OnSuccess,
}
41
/// Content encoding of a response body, when it is known.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentEncoding {
    /// The body is gzip-compressed.
    Gzip,
}
48
49#[derive(Debug, Clone)]
50/// Builder for a single download request.
51pub struct RequestBuilder {
52    pub(crate) url: String,
53    pub(crate) headers: Vec<(String, String)>,
54    pub(crate) preferred: Vec<Downloader>,
55    pub(crate) follow_redirects: bool,
56    pub(crate) quiet: Quiet,
57}
58
59#[derive(Debug, Clone)]
60/// Low-level download result prior to finalizing the output file.
61pub struct DownloadResult {
62    /// HTTP status code (best-effort).
63    pub status_code: u16,
64    /// Response content encoding, if known.
65    pub content_encoding: Option<ContentEncoding>,
66}
67
68impl RequestBuilder {
69    /// Create a new request builder.
70    pub fn new(url: impl Into<String>) -> Self {
71        Self {
72            url: url.into(),
73            headers: Vec::new(),
74            preferred: Vec::new(),
75            follow_redirects: true,
76            quiet: Quiet::Always,
77        }
78    }
79
80    /// Add an HTTP header.
81    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
82        self.headers.push((key.into(), value.into()));
83        self
84    }
85
86    /// Prefer a specific downloader backend.
87    pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
88        self.preferred.push(preferred);
89        self
90    }
91
92    /// Enable or disable HTTP redirect following.
93    pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
94        self.follow_redirects = follow_redirects;
95        self
96    }
97
98    /// Control forwarding of child output.
99    pub fn quiet(mut self, quiet: Quiet) -> Self {
100        self.quiet = quiet;
101        self
102    }
103
104    /// Fetch the response body as a String, blocking until the download is
105    /// complete.
106    pub fn fetch_string(self) -> Result<String, ResponseError> {
107        String::from_utf8(self.fetch_bytes()?)
108            .map_err(|e| ResponseError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))
109    }
110
111    /// Fetch the response body as a String, blocking until the download is
112    /// complete.
113    pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
114        // Reserve a final path but do not keep an open handle: `join` /
115        // `finalize_download` replace that path, which fails on Windows with
116        // ERROR_ACCESS_DENIED if our `TmpFile` still holds the file open.
117        let tmp = crate::tempfile::create_tmp_file_in_path(
118            "in-memory",
119            None,
120            &std::env::temp_dir(),
121            "shell-download-in-memory",
122        )
123        .map_err(ResponseError::Io)?;
124        let target_path = tmp.as_ref().to_path_buf();
125        drop(tmp);
126
127        let handle = self
128            .start_internal(target_path.clone())
129            .map_err(ResponseError::Start)?;
130        let _res = handle.join()?;
131
132        let out = std::fs::read(&target_path).map_err(ResponseError::Io)?;
133        let _ = std::fs::remove_file(&target_path);
134        Ok(out)
135    }
136
137    /// Start the download in a background thread.
138    pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
139        let target_path = target_path.as_ref().to_path_buf();
140
141        if let Some(parent) = target_path.parent() {
142            if !parent.as_os_str().is_empty() {
143                std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
144            }
145        }
146
147        let _ = std::fs::remove_file(&target_path);
148        self.start_internal(target_path)
149    }
150
151    fn start_internal(self, target_path: PathBuf) -> Result<RequestHandle, StartError> {
152        // URL preflight: fail early with a message useful to callers.
153        let url = url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
154
155        let parent = target_path.parent().unwrap_or_else(|| Path::new("."));
156        let hint = target_path
157            .file_name()
158            .and_then(|s| s.to_str())
159            .unwrap_or("download");
160        let tmp_path =
161            crate::tempfile::create_tmp_file_in_path("download", Some(&url), parent, hint)
162                .map_err(StartError::IoError)?;
163
164        let cancel = Arc::new(AtomicBool::new(false));
165        let mut saw_non_not_found: Option<io::Error> = None;
166        let mut saw_any_not_found = false;
167
168        for d in candidate_downloaders(&self.preferred) {
169            match d
170                .driver()
171                .start(self.clone(), tmp_path.as_ref(), Arc::clone(&cancel))
172            {
173                Ok(join) => {
174                    return Ok(RequestHandle {
175                        cancel,
176                        join: Some(join),
177                        target_path,
178                        tmp_path: Some(tmp_path),
179                    });
180                }
181                Err(StartError::NoDriverFound) => {
182                    saw_any_not_found = true;
183                    continue;
184                }
185                Err(StartError::IoError(e)) => {
186                    if saw_non_not_found.is_none() {
187                        saw_non_not_found = Some(e);
188                    }
189                    continue;
190                }
191                Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
192            }
193        }
194
195        if let Some(e) = saw_non_not_found {
196            return Err(StartError::IoError(e));
197        }
198        if saw_any_not_found {
199            return Err(StartError::NoDriverFound);
200        }
201        Err(StartError::NoDriverFound)
202    }
203}
204
205impl Downloader {
206    pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
207        static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
208        static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
209        static POWERSHELL: drivers::powershell::PowerShellDriver =
210            drivers::powershell::PowerShellDriver;
211        static PYTHON3: drivers::python3::Python3Driver = drivers::python3::Python3Driver;
212        static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
213
214        match self {
215            Downloader::Curl => &CURL,
216            Downloader::Wget => &WGET,
217            Downloader::PowerShell => &POWERSHELL,
218            Downloader::Python3 => &PYTHON3,
219            Downloader::OpenSsl => &OPENSSL,
220        }
221    }
222}
223
224#[derive(Debug)]
225/// Handle for a running download.
226pub struct RequestHandle {
227    cancel: Arc<AtomicBool>,
228    join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
229    target_path: std::path::PathBuf,
230    tmp_path: Option<crate::tempfile::TmpFile>,
231}
232
233impl RequestHandle {
234    /// Request cancellation (best-effort).
235    pub fn cancel(&self) {
236        self.cancel.store(true, Ordering::SeqCst);
237    }
238
239    /// Wait for completion and finalize the output file.
240    pub fn join(mut self) -> Result<Response, ResponseError> {
241        let res = match self.join.take().expect("join called once").join() {
242            Ok(r) => r,
243            Err(_) => Err(ResponseError::ThreadPanicked),
244        }?;
245
246        let tmp_path = self.tmp_path.take().expect("tmp_path present");
247        util::finalize_download(tmp_path, &self.target_path, res.content_encoding)?;
248        Ok(Response {
249            status_code: res.status_code,
250        })
251    }
252}
253
254impl Drop for RequestHandle {
255    fn drop(&mut self) {
256        if self.join.is_some() {
257            self.cancel.store(true, Ordering::SeqCst);
258            // `tmp_path` will clean itself up via `Drop`.
259        }
260    }
261}
262
/// Metadata describing a download that ran to completion.
#[derive(Debug, Clone)]
pub struct Response {
    /// HTTP status code reported for the response (best-effort).
    pub status_code: u16,
}
269
/// Errors that can occur while starting a download.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// printed with `{}` and propagated through `Box<dyn Error>`.
#[derive(Debug)]
pub enum StartError {
    /// No usable backend executable was found.
    NoDriverFound,
    /// A local I/O error occurred.
    IoError(io::Error),
    /// URL validation failed.
    Url(String),
}

impl std::fmt::Display for StartError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoDriverFound => f.write_str("no usable download backend was found"),
            Self::IoError(e) => write!(f, "I/O error: {}", e),
            Self::Url(msg) => write!(f, "invalid URL: {}", msg),
        }
    }
}

impl std::error::Error for StartError {
    /// Expose the wrapped I/O error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(e) => Some(e),
            Self::NoDriverFound | Self::Url(_) => None,
        }
    }
}

impl From<io::Error> for StartError {
    fn from(value: io::Error) -> Self {
        Self::IoError(value)
    }
}

/// Errors that can occur while running a request.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// printed with `{}` and propagated through `Box<dyn Error>`.
#[derive(Debug)]
pub enum ResponseError {
    /// A local I/O error occurred.
    Io(io::Error),
    /// The URL could not be parsed.
    InvalidUrl,
    /// The URL scheme is unsupported.
    UnsupportedScheme,
    /// The request was cancelled.
    Cancelled,
    /// The worker thread panicked.
    ThreadPanicked,
    /// The backend command failed.
    CommandFailed {
        /// Backend program label.
        program: &'static str,
        /// Process exit code, if available.
        exit_code: Option<i32>,
        /// Captured stderr (best-effort).
        stderr: String,
    },
    /// The backend returned a non-numeric status code.
    BadStatusCode(String),
    /// Gzip decoding failed.
    GzipFailed {
        /// Process exit code, if available.
        exit_code: Option<i32>,
        /// Captured stderr (best-effort).
        stderr: String,
    },
    /// Download start failed.
    Start(StartError),
}

impl std::fmt::Display for ResponseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Write `what`, then the exit code and trimmed stderr when present.
        fn process_failure(
            f: &mut std::fmt::Formatter<'_>,
            what: &str,
            exit_code: Option<i32>,
            stderr: &str,
        ) -> std::fmt::Result {
            f.write_str(what)?;
            if let Some(code) = exit_code {
                write!(f, " (exit code {})", code)?;
            }
            let stderr = stderr.trim();
            if !stderr.is_empty() {
                write!(f, ": {}", stderr)?;
            }
            Ok(())
        }

        match self {
            Self::Io(e) => write!(f, "I/O error: {}", e),
            Self::InvalidUrl => f.write_str("the URL could not be parsed"),
            Self::UnsupportedScheme => f.write_str("the URL scheme is unsupported"),
            Self::Cancelled => f.write_str("the request was cancelled"),
            Self::ThreadPanicked => f.write_str("the worker thread panicked"),
            Self::CommandFailed {
                program,
                exit_code,
                stderr,
            } => process_failure(f, &format!("{} failed", program), *exit_code, stderr),
            Self::BadStatusCode(raw) => {
                write!(f, "backend returned a non-numeric status code: {:?}", raw)
            }
            Self::GzipFailed { exit_code, stderr } => {
                process_failure(f, "gzip decoding failed", *exit_code, stderr)
            }
            Self::Start(e) => write!(f, "failed to start download: {}", e),
        }
    }
}

impl std::error::Error for ResponseError {
    /// Expose the wrapped I/O or start error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::Start(e) => Some(e),
            Self::InvalidUrl
            | Self::UnsupportedScheme
            | Self::Cancelled
            | Self::ThreadPanicked
            | Self::CommandFailed { .. }
            | Self::BadStatusCode(_)
            | Self::GzipFailed { .. } => None,
        }
    }
}

impl From<io::Error> for ResponseError {
    fn from(value: io::Error) -> Self {
        Self::Io(value)
    }
}
327
328/// Choose downloaders in priority order.
329fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
330    if !preferred.is_empty() {
331        return preferred.to_vec();
332    }
333    vec![
334        Downloader::Curl,
335        Downloader::Wget,
336        Downloader::PowerShell,
337        Downloader::Python3,
338        Downloader::OpenSsl,
339    ]
340}