Skip to main content

upstream_rs/providers/http/
http_client.rs

1use anyhow::{Context, Result, bail};
2use chrono::{DateTime, Utc};
3use reqwest::{Client, StatusCode, header};
4use std::collections::HashSet;
5use std::path::Path;
6use tokio::fs::File;
7use tokio::io::AsyncWriteExt;
8
9use crate::models::common::enums::Filetype;
10use crate::utils::filename_parser::parse_filetype;
11
12#[derive(Debug, Clone)]
13pub struct HttpAssetInfo {
14    pub download_url: String,
15    pub name: String,
16    pub size: u64,
17    pub last_modified: Option<DateTime<Utc>>,
18    pub etag: Option<String>,
19}
20
21#[derive(Debug, Clone)]
22pub enum ConditionalProbeResult {
23    NotModified,
24    Asset(HttpAssetInfo),
25}
26
27#[derive(Debug, Clone)]
28pub enum ConditionalDiscoveryResult {
29    NotModified,
30    Assets(Vec<HttpAssetInfo>),
31}
32
33#[derive(Debug, Clone)]
34pub struct HttpClient {
35    client: Client,
36}
37
38impl HttpClient {
39    fn format_http_date(dt: DateTime<Utc>) -> String {
40        dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
41    }
42
43    fn add_if_modified_since(
44        mut request: reqwest::RequestBuilder,
45        last_upgraded: Option<DateTime<Utc>>,
46    ) -> reqwest::RequestBuilder {
47        if let Some(ts) = last_upgraded {
48            request = request.header(header::IF_MODIFIED_SINCE, Self::format_http_date(ts));
49        }
50        request
51    }
52
53    fn parse_last_modified(value: Option<&header::HeaderValue>) -> Option<DateTime<Utc>> {
54        let raw = value?.to_str().ok()?;
55        DateTime::parse_from_rfc2822(raw)
56            .ok()
57            .map(|dt| dt.with_timezone(&Utc))
58    }
59
60    fn parse_etag(value: Option<&header::HeaderValue>) -> Option<String> {
61        value
62            .and_then(|v| v.to_str().ok())
63            .map(str::trim)
64            .map(|s| s.trim_matches('"').to_string())
65            .filter(|s| !s.is_empty())
66    }
67
68    pub fn new() -> Result<Self> {
69        let mut headers = header::HeaderMap::new();
70
71        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
72        headers.insert(
73            header::USER_AGENT,
74            header::HeaderValue::from_str(&user_agent)
75                .context("Failed to create user agent header")?,
76        );
77
78        let client = Client::builder()
79            .default_headers(headers)
80            .build()
81            .context("Failed to build HTTP client")?;
82
83        Ok(Self { client })
84    }
85
86    /// Normalize provider inputs so bare hosts/slugs become HTTPS URLs.
87    pub fn normalize_url(url_or_slug: &str) -> String {
88        let raw = url_or_slug.trim();
89        if raw.starts_with("http://") || raw.starts_with("https://") {
90            raw.to_string()
91        } else {
92            format!("https://{}", raw)
93        }
94    }
95
96    /// Extract raw `href` attribute values from HTML without a full DOM parser.
97    fn extract_hrefs(html: &str) -> Vec<String> {
98        let mut hrefs = Vec::new();
99        let lower = html.to_lowercase();
100        let bytes = lower.as_bytes();
101        let mut i = 0_usize;
102
103        while i + 6 < bytes.len() {
104            if &bytes[i..i + 5] != b"href=" {
105                i += 1;
106                continue;
107            }
108            let mut j = i + 5;
109            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
110                j += 1;
111            }
112            if j >= bytes.len() {
113                break;
114            }
115
116            let quote = bytes[j];
117            if quote == b'"' || quote == b'\'' {
118                let start = j + 1;
119                let mut end = start;
120                while end < bytes.len() && bytes[end] != quote {
121                    end += 1;
122                }
123                if end <= html.len() && start <= end {
124                    let href = html[start..end].trim();
125                    if !href.is_empty() {
126                        hrefs.push(href.to_string());
127                    }
128                }
129                i = end.saturating_add(1);
130                continue;
131            }
132
133            i = j.saturating_add(1);
134        }
135
136        hrefs
137    }
138
139    fn to_asset_info(url: &str, headers: &header::HeaderMap) -> HttpAssetInfo {
140        HttpAssetInfo {
141            name: Self::file_name_from_url(url),
142            download_url: url.to_string(),
143            size: headers
144                .get(header::CONTENT_LENGTH)
145                .and_then(|v| v.to_str().ok())
146                .and_then(|s| s.parse::<u64>().ok())
147                .unwrap_or(0),
148            last_modified: Self::parse_last_modified(headers.get(header::LAST_MODIFIED)),
149            etag: Self::parse_etag(headers.get(header::ETAG)),
150        }
151    }
152
153    /// Convert discovered links into unique, non-checksum HTTP assets.
154    fn extract_assets_from_html(base: &reqwest::Url, html: &str) -> Vec<HttpAssetInfo> {
155        let hrefs = Self::extract_hrefs(html);
156
157        let mut seen = HashSet::new();
158        let mut assets = Vec::new();
159        for href in hrefs {
160            if href.starts_with('#')
161                || href.starts_with("javascript:")
162                || href.starts_with("mailto:")
163                || href.starts_with("tel:")
164            {
165                continue;
166            }
167
168            let Ok(joined) = base.join(&href) else {
169                continue;
170            };
171            if joined.scheme() != "http" && joined.scheme() != "https" {
172                continue;
173            }
174
175            let joined_str = joined.to_string();
176            let name = Self::file_name_from_url(&joined_str);
177            if name.is_empty() {
178                continue;
179            }
180
181            if parse_filetype(&name) == Filetype::Checksum {
182                continue;
183            }
184
185            if seen.insert(joined_str.clone()) {
186                assets.push(HttpAssetInfo {
187                    download_url: joined_str,
188                    name,
189                    size: 0,
190                    last_modified: None,
191                    etag: None,
192                });
193            }
194        }
195        assets
196    }
197
198    /// Discover downloadable assets from an HTTP endpoint with optional
199    /// `If-Modified-Since` behavior.
200    pub async fn discover_assets_if_modified_since(
201        &self,
202        url_or_slug: &str,
203        last_upgraded: Option<DateTime<Utc>>,
204    ) -> Result<ConditionalDiscoveryResult> {
205        let url = Self::normalize_url(url_or_slug);
206        let response = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
207            .send()
208            .await
209            .context(format!("Failed to send request to {}", url))?;
210
211        if response.status() == StatusCode::NOT_MODIFIED {
212            return Ok(ConditionalDiscoveryResult::NotModified);
213        }
214
215        response
216            .error_for_status_ref()
217            .context(format!("HTTP server returned error for {}", url))?;
218
219        let final_url = response.url().to_string();
220        let content_type = response
221            .headers()
222            .get(header::CONTENT_TYPE)
223            .and_then(|v| v.to_str().ok())
224            .unwrap_or("")
225            .to_lowercase();
226        let response_headers = response.headers().clone();
227
228        if !content_type.contains("text/html") {
229            return Ok(ConditionalDiscoveryResult::Assets(vec![
230                Self::to_asset_info(&final_url, response.headers()),
231            ]));
232        }
233
234        let base = reqwest::Url::parse(&final_url)
235            .context(format!("Failed to parse URL '{}'", final_url))?;
236        let body = response.text().await.context("Failed to read HTML body")?;
237        let assets = Self::extract_assets_from_html(&base, &body);
238
239        if assets.is_empty() {
240            Ok(ConditionalDiscoveryResult::Assets(vec![
241                Self::to_asset_info(&final_url, &response_headers),
242            ]))
243        } else {
244            Ok(ConditionalDiscoveryResult::Assets(assets))
245        }
246    }
247
248    /// Derive a filename from URL path segments with a safe fallback.
249    pub fn file_name_from_url(url: &str) -> String {
250        let without_fragment = url.split('#').next().unwrap_or(url);
251        let without_query = without_fragment
252            .split('?')
253            .next()
254            .unwrap_or(without_fragment);
255        let candidate = without_query.rsplit('/').next().unwrap_or("").trim();
256
257        if candidate.is_empty() {
258            "download.bin".to_string()
259        } else {
260            candidate.to_string()
261        }
262    }
263
264    pub async fn probe_asset(&self, url_or_slug: &str) -> Result<HttpAssetInfo> {
265        match self
266            .probe_asset_if_modified_since(url_or_slug, None)
267            .await?
268        {
269            ConditionalProbeResult::NotModified => {
270                bail!("Unexpected 304 Not Modified response without conditional timestamp")
271            }
272            ConditionalProbeResult::Asset(asset) => Ok(asset),
273        }
274    }
275
276    pub async fn probe_asset_if_modified_since(
277        &self,
278        url_or_slug: &str,
279        last_upgraded: Option<DateTime<Utc>>,
280    ) -> Result<ConditionalProbeResult> {
281        let url = Self::normalize_url(url_or_slug);
282
283        let head_resp = Self::add_if_modified_since(self.client.head(&url), last_upgraded)
284            .send()
285            .await;
286
287        let (size, last_modified, etag) = match head_resp {
288            Ok(resp) if resp.status() == StatusCode::NOT_MODIFIED => {
289                return Ok(ConditionalProbeResult::NotModified);
290            }
291            Ok(resp) if resp.status().is_success() => {
292                let last_modified =
293                    Self::parse_last_modified(resp.headers().get(header::LAST_MODIFIED));
294                let etag = Self::parse_etag(resp.headers().get(header::ETAG));
295                (resp.content_length().unwrap_or(0), last_modified, etag)
296            }
297            Ok(resp)
298                if resp.status() == StatusCode::METHOD_NOT_ALLOWED
299                    || resp.status() == StatusCode::NOT_IMPLEMENTED =>
300            {
301                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
302                    .send()
303                    .await
304                    .context(format!("Failed to send request to {}", url))?;
305
306                if get_resp.status() == StatusCode::NOT_MODIFIED {
307                    return Ok(ConditionalProbeResult::NotModified);
308                }
309
310                get_resp
311                    .error_for_status_ref()
312                    .context(format!("HTTP server returned error for {}", url))?;
313                let last_modified =
314                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
315                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
316                (get_resp.content_length().unwrap_or(0), last_modified, etag)
317            }
318            Ok(resp) => {
319                bail!("HTTP server returned {} for {}", resp.status(), url);
320            }
321            Err(_) => {
322                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
323                    .send()
324                    .await
325                    .context(format!("Failed to send request to {}", url))?;
326
327                if get_resp.status() == StatusCode::NOT_MODIFIED {
328                    return Ok(ConditionalProbeResult::NotModified);
329                }
330
331                get_resp
332                    .error_for_status_ref()
333                    .context(format!("HTTP server returned error for {}", url))?;
334                let last_modified =
335                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
336                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
337                (get_resp.content_length().unwrap_or(0), last_modified, etag)
338            }
339        };
340
341        Ok(ConditionalProbeResult::Asset(HttpAssetInfo {
342            name: Self::file_name_from_url(&url),
343            download_url: url,
344            size,
345            last_modified,
346            etag,
347        }))
348    }
349
350    pub async fn download_file<F>(
351        &self,
352        url: &str,
353        destination: &Path,
354        progress: &mut Option<F>,
355    ) -> Result<()>
356    where
357        F: FnMut(u64, u64),
358    {
359        let response = self
360            .client
361            .get(url)
362            .send()
363            .await
364            .context(format!("Failed to download from {}", url))?;
365
366        response
367            .error_for_status_ref()
368            .context("Download request failed")?;
369
370        let total_bytes = response.content_length().unwrap_or(0);
371
372        let mut file = File::create(destination)
373            .await
374            .context(format!("Failed to create file at {:?}", destination))?;
375
376        let mut stream = response.bytes_stream();
377        let mut total_read: u64 = 0;
378
379        use futures_util::StreamExt;
380        while let Some(chunk) = stream.next().await {
381            let chunk = chunk.context("Failed to read download chunk")?;
382
383            file.write_all(&chunk)
384                .await
385                .context("Failed to write to file")?;
386
387            total_read += chunk.len() as u64;
388
389            if let Some(cb) = progress.as_mut() {
390                cb(total_read, total_bytes);
391            }
392        }
393
394        file.flush().await.context("Failed to flush file")?;
395
396        if total_bytes > 0 && total_read != total_bytes {
397            bail!(
398                "Download size mismatch: expected {} bytes, got {} bytes",
399                total_bytes,
400                total_read
401            );
402        }
403
404        Ok(())
405    }
406}
407
408#[cfg(test)]
409mod tests {
410    use super::{ConditionalDiscoveryResult, ConditionalProbeResult, HttpClient};
411    use chrono::Utc;
412    use std::io::{BufRead, BufReader, Write};
413    use std::net::TcpListener;
414    use std::path::{Path, PathBuf};
415    use std::sync::mpsc;
416    use std::thread;
417    use std::time::{SystemTime, UNIX_EPOCH};
418    use std::{fs, io};
419
420    fn spawn_test_server<F>(max_requests: usize, handler: F) -> String
421    where
422        F: Fn(&str, &str) -> String + Send + 'static,
423    {
424        let (tx, rx) = mpsc::channel();
425        thread::spawn(move || {
426            let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
427            let addr = listener.local_addr().expect("resolve local addr");
428            tx.send(addr).expect("send test server addr");
429
430            for _ in 0..max_requests {
431                let (mut stream, _) = listener.accept().expect("accept request");
432                let cloned = stream.try_clone().expect("clone stream");
433                let mut reader = BufReader::new(cloned);
434
435                let mut request_line = String::new();
436                reader
437                    .read_line(&mut request_line)
438                    .expect("read request line");
439                let mut parts = request_line.split_whitespace();
440                let method = parts.next().unwrap_or("");
441                let path = parts.next().unwrap_or("/");
442
443                let mut line = String::new();
444                loop {
445                    line.clear();
446                    reader.read_line(&mut line).expect("read request headers");
447                    if line == "\r\n" || line.is_empty() {
448                        break;
449                    }
450                }
451
452                let response = handler(method, path);
453                stream
454                    .write_all(response.as_bytes())
455                    .expect("write response");
456                stream.flush().expect("flush response");
457            }
458        });
459
460        let addr = rx.recv().expect("receive server address");
461        format!("http://{}", addr)
462    }
463
464    fn http_response(status_line: &str, headers: &[(&str, &str)], body: &str) -> String {
465        let mut out = format!("{status_line}\r\n");
466        for (k, v) in headers {
467            out.push_str(&format!("{k}: {v}\r\n"));
468        }
469        out.push_str("\r\n");
470        out.push_str(body);
471        out
472    }
473
474    fn temp_file_path(name: &str) -> PathBuf {
475        let nanos = SystemTime::now()
476            .duration_since(UNIX_EPOCH)
477            .map(|d| d.as_nanos())
478            .unwrap_or(0);
479        std::env::temp_dir().join(format!("upstream-http-test-{name}-{nanos}.bin"))
480    }
481
482    fn cleanup_file(path: &Path) -> io::Result<()> {
483        if path.exists() {
484            fs::remove_file(path)?;
485        }
486        Ok(())
487    }
488
489    #[test]
490    fn normalize_url_and_file_name_from_url_behave_as_expected() {
491        assert_eq!(
492            HttpClient::normalize_url("example.com/a"),
493            "https://example.com/a"
494        );
495        assert_eq!(
496            HttpClient::normalize_url("http://example.com/a"),
497            "http://example.com/a"
498        );
499
500        assert_eq!(
501            HttpClient::file_name_from_url("https://x.invalid/path/tool.tar.gz?x=1#frag"),
502            "tool.tar.gz"
503        );
504        assert_eq!(
505            HttpClient::file_name_from_url("https://x.invalid/path/"),
506            "download.bin"
507        );
508    }
509
510    #[tokio::test]
511    async fn discover_assets_extracts_and_filters_html_links() {
512        let html = r##"
513                <html><body>
514                    <a href="tool-v1.2.3-linux.tar.gz">main</a>
515                    <a href="/downloads/tool-v1.2.3-linux.tar.gz">duplicate</a>
516                    <a href="tool-v1.2.3.sha256">checksum</a>
517                    <a href="mailto:test@example.com">mail</a>
518                    <a href="#anchor">anchor</a>
519                    <a href="https://example.invalid/tool-v1.2.3-macos.zip">mac</a>
520                </body></html>
521            "##;
522        let body = html.to_string();
523        let server = spawn_test_server(1, move |_, _| {
524            http_response(
525                "HTTP/1.1 200 OK",
526                &[
527                    ("Content-Type", "text/html"),
528                    ("Content-Length", &body.len().to_string()),
529                    ("Connection", "close"),
530                ],
531                &body,
532            )
533        });
534        let client = HttpClient::new().expect("client");
535
536        let result = client
537            .discover_assets_if_modified_since(&server, None)
538            .await
539            .expect("discover assets");
540
541        match result {
542            ConditionalDiscoveryResult::NotModified => panic!("unexpected not modified"),
543            ConditionalDiscoveryResult::Assets(assets) => {
544                assert_eq!(assets.len(), 3);
545                assert!(
546                    assets
547                        .iter()
548                        .any(|a| a.name.ends_with("tool-v1.2.3-linux.tar.gz"))
549                );
550                assert!(assets.iter().all(|a| !a.name.ends_with(".sha256")));
551            }
552        }
553    }
554
555    #[tokio::test]
556    async fn probe_asset_if_modified_since_returns_not_modified_on_304() {
557        let server = spawn_test_server(1, move |method, _| {
558            assert_eq!(method, "HEAD");
559            http_response("HTTP/1.1 304 Not Modified", &[("Connection", "close")], "")
560        });
561        let client = HttpClient::new().expect("client");
562
563        let result = client
564            .probe_asset_if_modified_since(&server, Some(Utc::now()))
565            .await
566            .expect("probe");
567        assert!(matches!(result, ConditionalProbeResult::NotModified));
568    }
569
570    #[tokio::test]
571    async fn probe_asset_if_modified_since_falls_back_to_get_on_405_head() {
572        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
573        let etag = "\"abc123\"".to_string();
574        let server = spawn_test_server(2, move |method, _| match method {
575            "HEAD" => http_response(
576                "HTTP/1.1 405 Method Not Allowed",
577                &[("Connection", "close"), ("Content-Length", "0")],
578                "",
579            ),
580            "GET" => http_response(
581                "HTTP/1.1 200 OK",
582                &[
583                    ("Connection", "close"),
584                    ("Content-Length", "11"),
585                    ("Last-Modified", &last_modified),
586                    ("ETag", &etag),
587                ],
588                "hello world",
589            ),
590            _ => http_response(
591                "HTTP/1.1 500 Internal Server Error",
592                &[("Connection", "close"), ("Content-Length", "0")],
593                "",
594            ),
595        });
596        let client = HttpClient::new().expect("client");
597
598        let result = client
599            .probe_asset_if_modified_since(&format!("{server}/tool-v2.3.4.tar.gz"), None)
600            .await
601            .expect("probe fallback");
602
603        match result {
604            ConditionalProbeResult::NotModified => panic!("unexpected not modified"),
605            ConditionalProbeResult::Asset(asset) => {
606                assert_eq!(asset.size, 11);
607                assert_eq!(asset.etag.as_deref(), Some("abc123"));
608                assert!(asset.last_modified.is_some());
609                assert_eq!(asset.name, "tool-v2.3.4.tar.gz");
610            }
611        }
612    }
613
614    #[tokio::test]
615    async fn download_file_writes_bytes_and_reports_progress() {
616        let body = "stream-body-data".to_string();
617        let len = body.len().to_string();
618        let body_for_server = body.clone();
619        let server = spawn_test_server(1, move |method, _| {
620            assert_eq!(method, "GET");
621            http_response(
622                "HTTP/1.1 200 OK",
623                &[
624                    ("Connection", "close"),
625                    ("Content-Type", "application/octet-stream"),
626                    ("Content-Length", &len),
627                ],
628                &body_for_server,
629            )
630        });
631        let client = HttpClient::new().expect("client");
632        let output = temp_file_path("download");
633        let mut progress = Vec::new();
634        let mut cb = Some(|downloaded: u64, total: u64| {
635            progress.push((downloaded, total));
636        });
637
638        client
639            .download_file(&server, &output, &mut cb)
640            .await
641            .expect("download file");
642
643        assert_eq!(fs::read_to_string(&output).expect("read output file"), body);
644        assert!(!progress.is_empty());
645        assert_eq!(
646            progress.last().copied().expect("final progress"),
647            (body.len() as u64, body.len() as u64)
648        );
649
650        cleanup_file(&output).expect("cleanup output file");
651    }
652}