1mod drivers;
2mod url_parser;
3mod util;
4
5use std::io;
6use std::path::Path;
7use std::sync::{
8 Arc,
9 atomic::{AtomicBool, Ordering},
10};
11use std::thread::JoinHandle;
12
/// Identifies one of the download back-ends this crate can drive.
///
/// Each variant is mapped to a concrete driver by `Downloader::driver`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Downloader {
    /// Served by `drivers::curl::CurlDriver`.
    Curl,
    /// Served by `drivers::wget::WgetDriver`.
    Wget,
    /// Served by `drivers::powershell::PowerShellDriver`.
    PowerShell,
    /// Served by `drivers::openssl::OpenSslDriver`.
    OpenSsl,
}
20
/// Quietness policy stored on a `RequestBuilder`.
///
/// NOTE(review): the variants are only stored in this file; how each one is
/// interpreted is decided by the drivers, which are not visible here.
/// Presumably it controls when a downloader's own output is suppressed —
/// confirm against the driver implementations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quiet {
    /// Presumably: never suppress output.
    Never,
    /// Presumably: always suppress output. This is the value `RequestBuilder::new` uses.
    Always,
    /// Presumably: suppress output only when the download succeeds.
    OnSuccess,
}
30
31#[derive(Debug, Clone)]
32pub struct RequestBuilder {
33 pub(crate) url: String,
34 pub(crate) headers: Vec<(String, String)>,
35 pub(crate) preferred: Vec<Downloader>,
36 pub(crate) follow_redirects: bool,
37 pub(crate) quiet: Quiet,
38}
39
/// Value produced by a download worker thread when it finishes successfully.
///
/// `RequestHandle::join` reads both fields: the status code is copied into the
/// public `Response`, and the gzip flag is forwarded to
/// `util::finalize_download`.
#[derive(Clone, Debug)]
pub struct DownloadResult {
    /// Status code reported by the downloader (presumably the HTTP status).
    pub status_code: u16,
    /// Whether the downloaded body was reported as gzip content-encoded.
    pub content_encoding_gzip: bool,
}
45
46impl RequestBuilder {
47 pub fn new(url: impl Into<String>) -> Self {
48 Self {
49 url: url.into(),
50 headers: Vec::new(),
51 preferred: Vec::new(),
52 follow_redirects: true,
53 quiet: Quiet::Always,
54 }
55 }
56
57 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
58 self.headers.push((key.into(), value.into()));
59 self
60 }
61
62 pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
63 self.preferred.push(preferred);
64 self
65 }
66
67 pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
68 self.follow_redirects = follow_redirects;
69 self
70 }
71
72 pub fn quiet(mut self, quiet: Quiet) -> Self {
73 self.quiet = quiet;
74 self
75 }
76
77 #[cfg(feature = "in-memory")]
80 pub fn fetch_string(self) -> Result<String, ResponseError> {
81 let tmp_file = tempfile::NamedTempFile::new()?;
82 let handle = self
83 .start(tmp_file.path())
84 .map_err(ResponseError::Start)?;
85 let _res = handle.join()?;
86 std::fs::read_to_string(tmp_file.path()).map_err(ResponseError::Io)
87 }
88
89 #[cfg(feature = "in-memory")]
92 pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
93 let tmp_file = tempfile::NamedTempFile::new()?;
94 let handle = self
95 .start(tmp_file.path())
96 .map_err(ResponseError::Start)?;
97 let _res = handle.join()?;
98 std::fs::read(tmp_file.path()).map_err(ResponseError::Io)
99 }
100
101 pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
102 let target_path = target_path.as_ref().to_path_buf();
103
104 if let Some(parent) = target_path.parent() {
105 if !parent.as_os_str().is_empty() {
106 std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
107 }
108 }
109
110 let _ = std::fs::remove_file(&target_path);
111
112 url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
114
115 let tmp_path = util::tmp_path_for_target(&target_path);
116 let _ = std::fs::remove_file(&tmp_path);
117
118 let cancel = Arc::new(AtomicBool::new(false));
119 let mut saw_non_not_found: Option<io::Error> = None;
120 let mut saw_any_not_found = false;
121
122 for d in candidate_downloaders(&self.preferred) {
123 match d
124 .driver()
125 .start(self.clone(), tmp_path.clone(), Arc::clone(&cancel))
126 {
127 Ok(join) => {
128 return Ok(RequestHandle {
129 cancel,
130 join: Some(join),
131 target_path,
132 tmp_path,
133 });
134 }
135 Err(StartError::NoDriverFound) => {
136 saw_any_not_found = true;
137 continue;
138 }
139 Err(StartError::IoError(e)) => {
140 if saw_non_not_found.is_none() {
141 saw_non_not_found = Some(e);
142 }
143 continue;
144 }
145 Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
146 }
147 }
148
149 if let Some(e) = saw_non_not_found {
150 return Err(StartError::IoError(e));
151 }
152 if saw_any_not_found {
153 return Err(StartError::NoDriverFound);
154 }
155 Err(StartError::NoDriverFound)
156 }
157}
158
159impl Downloader {
160 pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
161 static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
162 static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
163 static POWERSHELL: drivers::powershell::PowerShellDriver =
164 drivers::powershell::PowerShellDriver;
165 static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
166
167 match self {
168 Downloader::Curl => &CURL,
169 Downloader::Wget => &WGET,
170 Downloader::PowerShell => &POWERSHELL,
171 Downloader::OpenSsl => &OPENSSL,
172 }
173 }
174}
175
176#[derive(Debug)]
177pub struct RequestHandle {
178 cancel: Arc<AtomicBool>,
179 join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
180 target_path: std::path::PathBuf,
181 tmp_path: std::path::PathBuf,
182}
183
184impl RequestHandle {
185 pub fn cancel(&self) {
186 self.cancel.store(true, Ordering::SeqCst);
187 }
188
189 pub fn join(mut self) -> Result<Response, ResponseError> {
190 let res = match self.join.take().expect("join called once").join() {
191 Ok(r) => r,
192 Err(_) => Err(ResponseError::ThreadPanicked),
193 }?;
194
195 util::finalize_download(&self.tmp_path, &self.target_path, res.content_encoding_gzip)?;
196 Ok(Response {
197 status_code: res.status_code,
198 })
199 }
200}
201
202impl Drop for RequestHandle {
203 fn drop(&mut self) {
204 if self.join.is_some() {
205 self.cancel.store(true, Ordering::SeqCst);
206 let _ = std::fs::remove_file(&self.tmp_path);
207 }
208 }
209}
210
/// Result of a download that completed and was finalized.
///
/// Returned by `RequestHandle::join`.
#[derive(Clone, Debug)]
pub struct Response {
    /// Status code copied from the worker's `DownloadResult`.
    pub status_code: u16,
}
215
/// Reasons a download could not be started.
///
/// Returned by `RequestBuilder::start` and by the individual drivers.
#[derive(Debug)]
pub enum StartError {
    /// No candidate driver could be started.
    NoDriverFound,
    /// An I/O operation failed while preparing or starting the download.
    IoError(io::Error),
    /// The URL was rejected; the payload is the parser's or driver's message.
    Url(String),
}
222
223impl From<io::Error> for StartError {
224 fn from(value: io::Error) -> Self {
225 Self::IoError(value)
226 }
227}
228
229#[derive(Debug)]
230pub enum ResponseError {
231 Io(io::Error),
232 InvalidUrl,
233 UnsupportedScheme,
234 Cancelled,
235 ThreadPanicked,
236 CommandFailed {
237 program: &'static str,
238 exit_code: Option<i32>,
239 stderr: String,
240 },
241 BadStatusCode(String),
242 GzipFailed {
243 exit_code: Option<i32>,
244 stderr: String,
245 },
246 Start(StartError),
247}
248
249impl From<io::Error> for ResponseError {
250 fn from(value: io::Error) -> Self {
251 Self::Io(value)
252 }
253}
254
255fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
256 if !preferred.is_empty() {
257 return preferred.to_vec();
258 }
259 vec![
260 Downloader::Curl,
261 Downloader::Wget,
262 Downloader::PowerShell,
263 Downloader::OpenSsl,
264 ]
265}