1use anyhow::{Context, Result, bail};
2use chrono::{DateTime, Utc};
3use reqwest::{Client, StatusCode, header};
4use std::collections::HashSet;
5use std::path::Path;
6use tokio::fs::File;
7use tokio::io::AsyncWriteExt;
8
9use crate::models::common::enums::Filetype;
10use crate::utils::filename_parser::parse_filetype;
11
/// Metadata describing one downloadable asset reachable over HTTP.
#[derive(Debug, Clone)]
pub struct HttpAssetInfo {
    /// URL the asset can be downloaded from.
    pub download_url: String,
    /// File name for the asset; the constructors in this file derive it from the URL.
    pub name: String,
    /// Size in bytes; the constructors in this file store `0` when no size is known.
    pub size: u64,
    /// Parsed `Last-Modified` response header, when present and parseable.
    pub last_modified: Option<DateTime<Utc>>,
    /// `ETag` response header with surrounding double quotes removed, when present.
    pub etag: Option<String>,
}
20
/// Outcome of probing a single asset URL, optionally with `If-Modified-Since`.
#[derive(Debug, Clone)]
pub enum ConditionalProbeResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// The server answered successfully; carries the metadata read from the response.
    Asset(HttpAssetInfo),
}
26
/// Outcome of discovering assets behind a URL, optionally with `If-Modified-Since`.
#[derive(Debug, Clone)]
pub enum ConditionalDiscoveryResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// The assets found at the URL (either links scraped from an HTML page or the
    /// URL itself as a single asset).
    Assets(Vec<HttpAssetInfo>),
}
32
/// HTTP helper for probing, discovering and downloading assets.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying reqwest client; `HttpClient::new` configures it with a default
    /// `User-Agent` header.
    client: Client,
}
37
38impl HttpClient {
39 fn format_http_date(dt: DateTime<Utc>) -> String {
40 dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
41 }
42
43 fn add_if_modified_since(
44 mut request: reqwest::RequestBuilder,
45 last_upgraded: Option<DateTime<Utc>>,
46 ) -> reqwest::RequestBuilder {
47 if let Some(ts) = last_upgraded {
48 request = request.header(header::IF_MODIFIED_SINCE, Self::format_http_date(ts));
49 }
50 request
51 }
52
53 fn parse_last_modified(value: Option<&header::HeaderValue>) -> Option<DateTime<Utc>> {
54 let raw = value?.to_str().ok()?;
55 DateTime::parse_from_rfc2822(raw)
56 .ok()
57 .map(|dt| dt.with_timezone(&Utc))
58 }
59
60 fn parse_etag(value: Option<&header::HeaderValue>) -> Option<String> {
61 value
62 .and_then(|v| v.to_str().ok())
63 .map(str::trim)
64 .map(|s| s.trim_matches('"').to_string())
65 .filter(|s| !s.is_empty())
66 }
67
68 pub fn new() -> Result<Self> {
69 let mut headers = header::HeaderMap::new();
70
71 let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
72 headers.insert(
73 header::USER_AGENT,
74 header::HeaderValue::from_str(&user_agent)
75 .context("Failed to create user agent header")?,
76 );
77
78 let client = Client::builder()
79 .default_headers(headers)
80 .build()
81 .context("Failed to build HTTP client")?;
82
83 Ok(Self { client })
84 }
85
86 pub fn normalize_url(url_or_slug: &str) -> String {
88 let raw = url_or_slug.trim();
89 if raw.starts_with("http://") || raw.starts_with("https://") {
90 raw.to_string()
91 } else {
92 format!("https://{}", raw)
93 }
94 }
95
96 fn extract_hrefs(html: &str) -> Vec<String> {
98 let mut hrefs = Vec::new();
99 let lower = html.to_lowercase();
100 let bytes = lower.as_bytes();
101 let mut i = 0_usize;
102
103 while i + 6 < bytes.len() {
104 if &bytes[i..i + 5] != b"href=" {
105 i += 1;
106 continue;
107 }
108 let mut j = i + 5;
109 while j < bytes.len() && bytes[j].is_ascii_whitespace() {
110 j += 1;
111 }
112 if j >= bytes.len() {
113 break;
114 }
115
116 let quote = bytes[j];
117 if quote == b'"' || quote == b'\'' {
118 let start = j + 1;
119 let mut end = start;
120 while end < bytes.len() && bytes[end] != quote {
121 end += 1;
122 }
123 if end <= html.len() && start <= end {
124 let href = html[start..end].trim();
125 if !href.is_empty() {
126 hrefs.push(href.to_string());
127 }
128 }
129 i = end.saturating_add(1);
130 continue;
131 }
132
133 i = j.saturating_add(1);
134 }
135
136 hrefs
137 }
138
139 fn to_asset_info(url: &str, headers: &header::HeaderMap) -> HttpAssetInfo {
140 HttpAssetInfo {
141 name: Self::file_name_from_url(url),
142 download_url: url.to_string(),
143 size: headers
144 .get(header::CONTENT_LENGTH)
145 .and_then(|v| v.to_str().ok())
146 .and_then(|s| s.parse::<u64>().ok())
147 .unwrap_or(0),
148 last_modified: Self::parse_last_modified(headers.get(header::LAST_MODIFIED)),
149 etag: Self::parse_etag(headers.get(header::ETAG)),
150 }
151 }
152
153 fn extract_assets_from_html(base: &reqwest::Url, html: &str) -> Vec<HttpAssetInfo> {
155 let hrefs = Self::extract_hrefs(html);
156
157 let mut seen = HashSet::new();
158 let mut assets = Vec::new();
159 for href in hrefs {
160 if href.starts_with('#')
161 || href.starts_with("javascript:")
162 || href.starts_with("mailto:")
163 || href.starts_with("tel:")
164 {
165 continue;
166 }
167
168 let Ok(joined) = base.join(&href) else {
169 continue;
170 };
171 if joined.scheme() != "http" && joined.scheme() != "https" {
172 continue;
173 }
174
175 let joined_str = joined.to_string();
176 let name = Self::file_name_from_url(&joined_str);
177 if name.is_empty() {
178 continue;
179 }
180
181 if parse_filetype(&name) == Filetype::Checksum {
182 continue;
183 }
184
185 if seen.insert(joined_str.clone()) {
186 assets.push(HttpAssetInfo {
187 download_url: joined_str,
188 name,
189 size: 0,
190 last_modified: None,
191 etag: None,
192 });
193 }
194 }
195 assets
196 }
197
198 pub async fn discover_assets_if_modified_since(
201 &self,
202 url_or_slug: &str,
203 last_upgraded: Option<DateTime<Utc>>,
204 ) -> Result<ConditionalDiscoveryResult> {
205 let url = Self::normalize_url(url_or_slug);
206 let response = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
207 .send()
208 .await
209 .context(format!("Failed to send request to {}", url))?;
210
211 if response.status() == StatusCode::NOT_MODIFIED {
212 return Ok(ConditionalDiscoveryResult::NotModified);
213 }
214
215 response
216 .error_for_status_ref()
217 .context(format!("HTTP server returned error for {}", url))?;
218
219 let final_url = response.url().to_string();
220 let content_type = response
221 .headers()
222 .get(header::CONTENT_TYPE)
223 .and_then(|v| v.to_str().ok())
224 .unwrap_or("")
225 .to_lowercase();
226 let response_headers = response.headers().clone();
227
228 if !content_type.contains("text/html") {
229 return Ok(ConditionalDiscoveryResult::Assets(vec![
230 Self::to_asset_info(&final_url, response.headers()),
231 ]));
232 }
233
234 let base = reqwest::Url::parse(&final_url)
235 .context(format!("Failed to parse URL '{}'", final_url))?;
236 let body = response.text().await.context("Failed to read HTML body")?;
237 let assets = Self::extract_assets_from_html(&base, &body);
238
239 if assets.is_empty() {
240 Ok(ConditionalDiscoveryResult::Assets(vec![
241 Self::to_asset_info(&final_url, &response_headers),
242 ]))
243 } else {
244 Ok(ConditionalDiscoveryResult::Assets(assets))
245 }
246 }
247
248 pub fn file_name_from_url(url: &str) -> String {
250 let without_fragment = url.split('#').next().unwrap_or(url);
251 let without_query = without_fragment
252 .split('?')
253 .next()
254 .unwrap_or(without_fragment);
255 let candidate = without_query.rsplit('/').next().unwrap_or("").trim();
256
257 if candidate.is_empty() {
258 "download.bin".to_string()
259 } else {
260 candidate.to_string()
261 }
262 }
263
264 pub async fn probe_asset(&self, url_or_slug: &str) -> Result<HttpAssetInfo> {
265 match self
266 .probe_asset_if_modified_since(url_or_slug, None)
267 .await?
268 {
269 ConditionalProbeResult::NotModified => {
270 bail!("Unexpected 304 Not Modified response without conditional timestamp")
271 }
272 ConditionalProbeResult::Asset(asset) => Ok(asset),
273 }
274 }
275
276 pub async fn probe_asset_if_modified_since(
277 &self,
278 url_or_slug: &str,
279 last_upgraded: Option<DateTime<Utc>>,
280 ) -> Result<ConditionalProbeResult> {
281 let url = Self::normalize_url(url_or_slug);
282
283 let head_resp = Self::add_if_modified_since(self.client.head(&url), last_upgraded)
284 .send()
285 .await;
286
287 let (size, last_modified, etag) = match head_resp {
288 Ok(resp) if resp.status() == StatusCode::NOT_MODIFIED => {
289 return Ok(ConditionalProbeResult::NotModified);
290 }
291 Ok(resp) if resp.status().is_success() => {
292 let last_modified =
293 Self::parse_last_modified(resp.headers().get(header::LAST_MODIFIED));
294 let etag = Self::parse_etag(resp.headers().get(header::ETAG));
295 (resp.content_length().unwrap_or(0), last_modified, etag)
296 }
297 Ok(resp)
298 if resp.status() == StatusCode::METHOD_NOT_ALLOWED
299 || resp.status() == StatusCode::NOT_IMPLEMENTED =>
300 {
301 let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
302 .send()
303 .await
304 .context(format!("Failed to send request to {}", url))?;
305
306 if get_resp.status() == StatusCode::NOT_MODIFIED {
307 return Ok(ConditionalProbeResult::NotModified);
308 }
309
310 get_resp
311 .error_for_status_ref()
312 .context(format!("HTTP server returned error for {}", url))?;
313 let last_modified =
314 Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
315 let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
316 (get_resp.content_length().unwrap_or(0), last_modified, etag)
317 }
318 Ok(resp) => {
319 bail!("HTTP server returned {} for {}", resp.status(), url);
320 }
321 Err(_) => {
322 let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
323 .send()
324 .await
325 .context(format!("Failed to send request to {}", url))?;
326
327 if get_resp.status() == StatusCode::NOT_MODIFIED {
328 return Ok(ConditionalProbeResult::NotModified);
329 }
330
331 get_resp
332 .error_for_status_ref()
333 .context(format!("HTTP server returned error for {}", url))?;
334 let last_modified =
335 Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
336 let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
337 (get_resp.content_length().unwrap_or(0), last_modified, etag)
338 }
339 };
340
341 Ok(ConditionalProbeResult::Asset(HttpAssetInfo {
342 name: Self::file_name_from_url(&url),
343 download_url: url,
344 size,
345 last_modified,
346 etag,
347 }))
348 }
349
350 pub async fn download_file<F>(
351 &self,
352 url: &str,
353 destination: &Path,
354 progress: &mut Option<F>,
355 ) -> Result<()>
356 where
357 F: FnMut(u64, u64),
358 {
359 let response = self
360 .client
361 .get(url)
362 .send()
363 .await
364 .context(format!("Failed to download from {}", url))?;
365
366 response
367 .error_for_status_ref()
368 .context("Download request failed")?;
369
370 let total_bytes = response.content_length().unwrap_or(0);
371
372 let mut file = File::create(destination)
373 .await
374 .context(format!("Failed to create file at {:?}", destination))?;
375
376 let mut stream = response.bytes_stream();
377 let mut total_read: u64 = 0;
378
379 use futures_util::StreamExt;
380 while let Some(chunk) = stream.next().await {
381 let chunk = chunk.context("Failed to read download chunk")?;
382
383 file.write_all(&chunk)
384 .await
385 .context("Failed to write to file")?;
386
387 total_read += chunk.len() as u64;
388
389 if let Some(cb) = progress.as_mut() {
390 cb(total_read, total_bytes);
391 }
392 }
393
394 file.flush().await.context("Failed to flush file")?;
395
396 if total_bytes > 0 && total_read != total_bytes {
397 bail!(
398 "Download size mismatch: expected {} bytes, got {} bytes",
399 total_bytes,
400 total_read
401 );
402 }
403
404 Ok(())
405 }
406}
407
408#[cfg(test)]
409mod tests {
410 use super::{ConditionalDiscoveryResult, ConditionalProbeResult, HttpClient};
411 use chrono::Utc;
412 use std::io::{BufRead, BufReader, Write};
413 use std::net::TcpListener;
414 use std::path::{Path, PathBuf};
415 use std::sync::mpsc;
416 use std::thread;
417 use std::time::{SystemTime, UNIX_EPOCH};
418 use std::{fs, io};
419
420 fn spawn_test_server<F>(max_requests: usize, handler: F) -> String
421 where
422 F: Fn(&str, &str) -> String + Send + 'static,
423 {
424 let (tx, rx) = mpsc::channel();
425 thread::spawn(move || {
426 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
427 let addr = listener.local_addr().expect("resolve local addr");
428 tx.send(addr).expect("send test server addr");
429
430 for _ in 0..max_requests {
431 let (mut stream, _) = listener.accept().expect("accept request");
432 let cloned = stream.try_clone().expect("clone stream");
433 let mut reader = BufReader::new(cloned);
434
435 let mut request_line = String::new();
436 reader
437 .read_line(&mut request_line)
438 .expect("read request line");
439 let mut parts = request_line.split_whitespace();
440 let method = parts.next().unwrap_or("");
441 let path = parts.next().unwrap_or("/");
442
443 let mut line = String::new();
444 loop {
445 line.clear();
446 reader.read_line(&mut line).expect("read request headers");
447 if line == "\r\n" || line.is_empty() {
448 break;
449 }
450 }
451
452 let response = handler(method, path);
453 stream
454 .write_all(response.as_bytes())
455 .expect("write response");
456 stream.flush().expect("flush response");
457 }
458 });
459
460 let addr = rx.recv().expect("receive server address");
461 format!("http://{}", addr)
462 }
463
464 fn http_response(status_line: &str, headers: &[(&str, &str)], body: &str) -> String {
465 let mut out = format!("{status_line}\r\n");
466 for (k, v) in headers {
467 out.push_str(&format!("{k}: {v}\r\n"));
468 }
469 out.push_str("\r\n");
470 out.push_str(body);
471 out
472 }
473
474 fn temp_file_path(name: &str) -> PathBuf {
475 let nanos = SystemTime::now()
476 .duration_since(UNIX_EPOCH)
477 .map(|d| d.as_nanos())
478 .unwrap_or(0);
479 std::env::temp_dir().join(format!("upstream-http-test-{name}-{nanos}.bin"))
480 }
481
482 fn cleanup_file(path: &Path) -> io::Result<()> {
483 if path.exists() {
484 fs::remove_file(path)?;
485 }
486 Ok(())
487 }
488
489 #[test]
490 fn normalize_url_and_file_name_from_url_behave_as_expected() {
491 assert_eq!(
492 HttpClient::normalize_url("example.com/a"),
493 "https://example.com/a"
494 );
495 assert_eq!(
496 HttpClient::normalize_url("http://example.com/a"),
497 "http://example.com/a"
498 );
499
500 assert_eq!(
501 HttpClient::file_name_from_url("https://x.invalid/path/tool.tar.gz?x=1#frag"),
502 "tool.tar.gz"
503 );
504 assert_eq!(
505 HttpClient::file_name_from_url("https://x.invalid/path/"),
506 "download.bin"
507 );
508 }
509
510 #[tokio::test]
511 async fn discover_assets_extracts_and_filters_html_links() {
512 let html = r##"
513 <html><body>
514 <a href="tool-v1.2.3-linux.tar.gz">main</a>
515 <a href="/downloads/tool-v1.2.3-linux.tar.gz">duplicate</a>
516 <a href="tool-v1.2.3.sha256">checksum</a>
517 <a href="mailto:test@example.com">mail</a>
518 <a href="#anchor">anchor</a>
519 <a href="https://example.invalid/tool-v1.2.3-macos.zip">mac</a>
520 </body></html>
521 "##;
522 let body = html.to_string();
523 let server = spawn_test_server(1, move |_, _| {
524 http_response(
525 "HTTP/1.1 200 OK",
526 &[
527 ("Content-Type", "text/html"),
528 ("Content-Length", &body.len().to_string()),
529 ("Connection", "close"),
530 ],
531 &body,
532 )
533 });
534 let client = HttpClient::new().expect("client");
535
536 let result = client
537 .discover_assets_if_modified_since(&server, None)
538 .await
539 .expect("discover assets");
540
541 match result {
542 ConditionalDiscoveryResult::NotModified => panic!("unexpected not modified"),
543 ConditionalDiscoveryResult::Assets(assets) => {
544 assert_eq!(assets.len(), 3);
545 assert!(
546 assets
547 .iter()
548 .any(|a| a.name.ends_with("tool-v1.2.3-linux.tar.gz"))
549 );
550 assert!(assets.iter().all(|a| !a.name.ends_with(".sha256")));
551 }
552 }
553 }
554
555 #[tokio::test]
556 async fn probe_asset_if_modified_since_returns_not_modified_on_304() {
557 let server = spawn_test_server(1, move |method, _| {
558 assert_eq!(method, "HEAD");
559 http_response("HTTP/1.1 304 Not Modified", &[("Connection", "close")], "")
560 });
561 let client = HttpClient::new().expect("client");
562
563 let result = client
564 .probe_asset_if_modified_since(&server, Some(Utc::now()))
565 .await
566 .expect("probe");
567 assert!(matches!(result, ConditionalProbeResult::NotModified));
568 }
569
570 #[tokio::test]
571 async fn probe_asset_if_modified_since_falls_back_to_get_on_405_head() {
572 let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
573 let etag = "\"abc123\"".to_string();
574 let server = spawn_test_server(2, move |method, _| match method {
575 "HEAD" => http_response(
576 "HTTP/1.1 405 Method Not Allowed",
577 &[("Connection", "close"), ("Content-Length", "0")],
578 "",
579 ),
580 "GET" => http_response(
581 "HTTP/1.1 200 OK",
582 &[
583 ("Connection", "close"),
584 ("Content-Length", "11"),
585 ("Last-Modified", &last_modified),
586 ("ETag", &etag),
587 ],
588 "hello world",
589 ),
590 _ => http_response(
591 "HTTP/1.1 500 Internal Server Error",
592 &[("Connection", "close"), ("Content-Length", "0")],
593 "",
594 ),
595 });
596 let client = HttpClient::new().expect("client");
597
598 let result = client
599 .probe_asset_if_modified_since(&format!("{server}/tool-v2.3.4.tar.gz"), None)
600 .await
601 .expect("probe fallback");
602
603 match result {
604 ConditionalProbeResult::NotModified => panic!("unexpected not modified"),
605 ConditionalProbeResult::Asset(asset) => {
606 assert_eq!(asset.size, 11);
607 assert_eq!(asset.etag.as_deref(), Some("abc123"));
608 assert!(asset.last_modified.is_some());
609 assert_eq!(asset.name, "tool-v2.3.4.tar.gz");
610 }
611 }
612 }
613
614 #[tokio::test]
615 async fn download_file_writes_bytes_and_reports_progress() {
616 let body = "stream-body-data".to_string();
617 let len = body.len().to_string();
618 let body_for_server = body.clone();
619 let server = spawn_test_server(1, move |method, _| {
620 assert_eq!(method, "GET");
621 http_response(
622 "HTTP/1.1 200 OK",
623 &[
624 ("Connection", "close"),
625 ("Content-Type", "application/octet-stream"),
626 ("Content-Length", &len),
627 ],
628 &body_for_server,
629 )
630 });
631 let client = HttpClient::new().expect("client");
632 let output = temp_file_path("download");
633 let mut progress = Vec::new();
634 let mut cb = Some(|downloaded: u64, total: u64| {
635 progress.push((downloaded, total));
636 });
637
638 client
639 .download_file(&server, &output, &mut cb)
640 .await
641 .expect("download file");
642
643 assert_eq!(fs::read_to_string(&output).expect("read output file"), body);
644 assert!(!progress.is_empty());
645 assert_eq!(
646 progress.last().copied().expect("final progress"),
647 (body.len() as u64, body.len() as u64)
648 );
649
650 cleanup_file(&output).expect("cleanup output file");
651 }
652}