1mod drivers;
2mod url_parser;
3mod util;
4
5use std::io;
6use std::path::Path;
7use std::sync::{
8 Arc,
9 atomic::{AtomicBool, Ordering},
10};
11use std::thread::JoinHandle;
12
/// External download tool used to perform a request.
///
/// Each variant maps to one driver in the `drivers` module (see `Downloader::driver`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Downloader {
    /// Backed by `drivers::curl::CurlDriver`.
    Curl,
    /// Backed by `drivers::wget::WgetDriver`.
    Wget,
    /// Backed by `drivers::powershell::PowerShellDriver`.
    PowerShell,
    /// Backed by `drivers::openssl::OpenSslDriver`.
    OpenSsl,
}
20
/// Output-suppression mode carried by a [`RequestBuilder`].
///
/// NOTE(review): presumably controls when the underlying tool's own output is silenced; the
/// drivers that interpret it are not visible here — confirm. `RequestBuilder::new` defaults to
/// `Always`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quiet {
    /// Presumably: never suppress output.
    Never,
    /// Presumably: always suppress output (the builder default).
    Always,
    /// Presumably: suppress output only when the download succeeds.
    OnSuccess,
}
30
31#[derive(Debug, Clone)]
32pub struct RequestBuilder {
33 pub(crate) url: String,
34 pub(crate) headers: Vec<(String, String)>,
35 pub(crate) preferred: Vec<Downloader>,
36 pub(crate) follow_redirects: bool,
37 pub(crate) quiet: Quiet,
38}
39
/// Outcome reported by a download worker thread before the file is finalized.
#[derive(Clone, Debug)]
pub struct DownloadResult {
    /// Status code reported by the driver.
    pub status_code: u16,
    /// Whether the body still needs gzip handling; passed to `util::finalize_download`.
    pub content_encoding_gzip: bool,
}
45
46impl RequestBuilder {
47 pub fn new(url: impl Into<String>) -> Self {
48 Self {
49 url: url.into(),
50 headers: Vec::new(),
51 preferred: Vec::new(),
52 follow_redirects: true,
53 quiet: Quiet::Always,
54 }
55 }
56
57 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
58 self.headers.push((key.into(), value.into()));
59 self
60 }
61
62 pub fn preferred_downloader(mut self, preferred: Downloader) -> Self {
63 self.preferred.push(preferred);
64 self
65 }
66
67 pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
68 self.follow_redirects = follow_redirects;
69 self
70 }
71
72 pub fn quiet(mut self, quiet: Quiet) -> Self {
73 self.quiet = quiet;
74 self
75 }
76
77 #[cfg(feature = "in-memory")]
80 pub fn fetch_string(self) -> Result<String, ResponseError> {
81 let tmp_file = tempfile::NamedTempFile::new()?;
82 let handle = self.start(tmp_file.path()).map_err(ResponseError::Start)?;
83 let _res = handle.join()?;
84 std::fs::read_to_string(tmp_file.path()).map_err(ResponseError::Io)
85 }
86
87 #[cfg(feature = "in-memory")]
90 pub fn fetch_bytes(self) -> Result<Vec<u8>, ResponseError> {
91 let tmp_file = tempfile::NamedTempFile::new()?;
92 let handle = self.start(tmp_file.path()).map_err(ResponseError::Start)?;
93 let _res = handle.join()?;
94 std::fs::read(tmp_file.path()).map_err(ResponseError::Io)
95 }
96
97 pub fn start(self, target_path: impl AsRef<Path>) -> Result<RequestHandle, StartError> {
98 let target_path = target_path.as_ref().to_path_buf();
99
100 if let Some(parent) = target_path.parent() {
101 if !parent.as_os_str().is_empty() {
102 std::fs::create_dir_all(parent).map_err(StartError::IoError)?;
103 }
104 }
105
106 let _ = std::fs::remove_file(&target_path);
107
108 url_parser::Url::new(&self.url).map_err(|e| StartError::Url(e.to_string()))?;
110
111 let tmp_path = util::tmp_path_for_target(&target_path);
112 let _ = std::fs::remove_file(&tmp_path);
113
114 let cancel = Arc::new(AtomicBool::new(false));
115 let mut saw_non_not_found: Option<io::Error> = None;
116 let mut saw_any_not_found = false;
117
118 for d in candidate_downloaders(&self.preferred) {
119 match d
120 .driver()
121 .start(self.clone(), tmp_path.clone(), Arc::clone(&cancel))
122 {
123 Ok(join) => {
124 return Ok(RequestHandle {
125 cancel,
126 join: Some(join),
127 target_path,
128 tmp_path,
129 });
130 }
131 Err(StartError::NoDriverFound) => {
132 saw_any_not_found = true;
133 continue;
134 }
135 Err(StartError::IoError(e)) => {
136 if saw_non_not_found.is_none() {
137 saw_non_not_found = Some(e);
138 }
139 continue;
140 }
141 Err(StartError::Url(msg)) => return Err(StartError::Url(msg)),
142 }
143 }
144
145 if let Some(e) = saw_non_not_found {
146 return Err(StartError::IoError(e));
147 }
148 if saw_any_not_found {
149 return Err(StartError::NoDriverFound);
150 }
151 Err(StartError::NoDriverFound)
152 }
153}
154
155impl Downloader {
156 pub(crate) fn driver(self) -> &'static dyn drivers::Driver {
157 static CURL: drivers::curl::CurlDriver = drivers::curl::CurlDriver;
158 static WGET: drivers::wget::WgetDriver = drivers::wget::WgetDriver;
159 static POWERSHELL: drivers::powershell::PowerShellDriver =
160 drivers::powershell::PowerShellDriver;
161 static OPENSSL: drivers::openssl::OpenSslDriver = drivers::openssl::OpenSslDriver;
162
163 match self {
164 Downloader::Curl => &CURL,
165 Downloader::Wget => &WGET,
166 Downloader::PowerShell => &POWERSHELL,
167 Downloader::OpenSsl => &OPENSSL,
168 }
169 }
170}
171
172#[derive(Debug)]
173pub struct RequestHandle {
174 cancel: Arc<AtomicBool>,
175 join: Option<JoinHandle<Result<DownloadResult, ResponseError>>>,
176 target_path: std::path::PathBuf,
177 tmp_path: std::path::PathBuf,
178}
179
180impl RequestHandle {
181 pub fn cancel(&self) {
182 self.cancel.store(true, Ordering::SeqCst);
183 }
184
185 pub fn join(mut self) -> Result<Response, ResponseError> {
186 let res = match self.join.take().expect("join called once").join() {
187 Ok(r) => r,
188 Err(_) => Err(ResponseError::ThreadPanicked),
189 }?;
190
191 util::finalize_download(&self.tmp_path, &self.target_path, res.content_encoding_gzip)?;
192 Ok(Response {
193 status_code: res.status_code,
194 })
195 }
196}
197
198impl Drop for RequestHandle {
199 fn drop(&mut self) {
200 if self.join.is_some() {
201 self.cancel.store(true, Ordering::SeqCst);
202 let _ = std::fs::remove_file(&self.tmp_path);
203 }
204 }
205}
206
/// Result of a finished, finalized download.
#[derive(Clone, Debug)]
pub struct Response {
    /// Status code reported by the downloader.
    pub status_code: u16,
}
211
/// Reasons a download could not be started.
#[derive(Debug)]
pub enum StartError {
    /// No candidate downloader could be started.
    NoDriverFound,
    /// A filesystem or process-level I/O failure.
    IoError(io::Error),
    /// The URL was rejected; carries the parser's message.
    Url(String),
}
218
219impl From<io::Error> for StartError {
220 fn from(value: io::Error) -> Self {
221 Self::IoError(value)
222 }
223}
224
225#[derive(Debug)]
226pub enum ResponseError {
227 Io(io::Error),
228 InvalidUrl,
229 UnsupportedScheme,
230 Cancelled,
231 ThreadPanicked,
232 CommandFailed {
233 program: &'static str,
234 exit_code: Option<i32>,
235 stderr: String,
236 },
237 BadStatusCode(String),
238 GzipFailed {
239 exit_code: Option<i32>,
240 stderr: String,
241 },
242 Start(StartError),
243}
244
245impl From<io::Error> for ResponseError {
246 fn from(value: io::Error) -> Self {
247 Self::Io(value)
248 }
249}
250
251fn candidate_downloaders(preferred: &[Downloader]) -> Vec<Downloader> {
252 if !preferred.is_empty() {
253 return preferred.to_vec();
254 }
255 vec![
256 Downloader::Curl,
257 Downloader::Wget,
258 Downloader::PowerShell,
259 Downloader::OpenSsl,
260 ]
261}