// upstream_rs/providers/http/http_client.rs

use anyhow::{Context, Result, bail};
use chrono::{DateTime, Utc};
use reqwest::{Client, StatusCode, header};
use std::collections::HashSet;
use std::path::Path;

use crate::models::common::enums::Filetype;
use crate::models::upstream::DownloadConfig;
use crate::providers::download_handler;
use crate::utils::filename_parser::parse_filetype;

12#[derive(Debug, Clone)]
13pub struct HttpAssetInfo {
14    pub download_url: String,
15    pub name: String,
16    pub size: u64,
17    pub last_modified: Option<DateTime<Utc>>,
18    pub etag: Option<String>,
19}
20
21#[derive(Debug, Clone)]
22pub enum ConditionalProbeResult {
23    NotModified,
24    Asset(HttpAssetInfo),
25}
26
27#[derive(Debug, Clone)]
28pub enum ConditionalDiscoveryResult {
29    NotModified,
30    Assets(Vec<HttpAssetInfo>),
31}
32
33#[derive(Debug, Clone)]
34pub struct HttpClient {
35    client: Client,
36    download_config: DownloadConfig,
37}
38
39impl HttpClient {
40    fn format_http_date(dt: DateTime<Utc>) -> String {
41        dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
42    }
43
44    fn add_if_modified_since(
45        mut request: reqwest::RequestBuilder,
46        last_upgraded: Option<DateTime<Utc>>,
47    ) -> reqwest::RequestBuilder {
48        if let Some(ts) = last_upgraded {
49            request = request.header(header::IF_MODIFIED_SINCE, Self::format_http_date(ts));
50        }
51        request
52    }
53
54    fn parse_last_modified(value: Option<&header::HeaderValue>) -> Option<DateTime<Utc>> {
55        let raw = value?.to_str().ok()?;
56        DateTime::parse_from_rfc2822(raw)
57            .ok()
58            .map(|dt| dt.with_timezone(&Utc))
59    }
60
61    fn parse_etag(value: Option<&header::HeaderValue>) -> Option<String> {
62        value
63            .and_then(|v| v.to_str().ok())
64            .map(str::trim)
65            .map(|s| s.trim_matches('"').to_string())
66            .filter(|s| !s.is_empty())
67    }
68
69    fn attribute_has_boundary(html: &str, index: usize, attribute: &str) -> bool {
70        let bytes = html.as_bytes();
71        let valid_start = index == 0
72            || bytes
73                .get(index.saturating_sub(1))
74                .map(|b| !b.is_ascii_alphanumeric() && *b != b'-')
75                .unwrap_or(true);
76        let end = index + attribute.len();
77        let valid_end = bytes
78            .get(end)
79            .map(|b| *b == b'=' || b.is_ascii_whitespace())
80            .unwrap_or(false);
81
82        valid_start && valid_end
83    }
84
85    pub fn new() -> Result<Self> {
86        Self::new_with_download_config(DownloadConfig::default())
87    }
88
89    pub fn new_with_download_config(download_config: DownloadConfig) -> Result<Self> {
90        let mut headers = header::HeaderMap::new();
91
92        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
93        headers.insert(
94            header::USER_AGENT,
95            header::HeaderValue::from_str(&user_agent)
96                .context("Failed to create user agent header")?,
97        );
98
99        let client = Client::builder()
100            .default_headers(headers)
101            .build()
102            .context("Failed to build HTTP client")?;
103
104        Ok(Self {
105            client,
106            download_config,
107        })
108    }
109
110    /// Normalize provider inputs so bare hosts/slugs become HTTPS URLs.
111    pub fn normalize_url(url_or_slug: &str) -> String {
112        let raw = url_or_slug.trim();
113        if raw.starts_with("http://") || raw.starts_with("https://") {
114            raw.to_string()
115        } else {
116            format!("https://{}", raw)
117        }
118    }
119
120    /// Extract likely download URL attribute values from HTML without a full DOM parser.
121    fn extract_link_values(html: &str) -> Vec<String> {
122        let attributes = [
123            "href",
124            "src",
125            "data-href",
126            "data-url",
127            "data-download",
128            "data-download-url",
129        ];
130        let mut values = Vec::new();
131        let lower = html.to_lowercase();
132        let bytes = lower.as_bytes();
133        let mut i = 0_usize;
134
135        while i < bytes.len() {
136            let Some((attribute, attr_offset)) = attributes
137                .iter()
138                .filter_map(|attribute| {
139                    lower[i..]
140                        .find(attribute)
141                        .map(|offset| (*attribute, offset))
142                })
143                .min_by(|(left_attr, left_offset), (right_attr, right_offset)| {
144                    left_offset
145                        .cmp(right_offset)
146                        .then_with(|| right_attr.len().cmp(&left_attr.len()))
147                })
148            else {
149                break;
150            };
151
152            i += attr_offset;
153            if !Self::attribute_has_boundary(&lower, i, attribute) {
154                i += 1;
155                continue;
156            }
157
158            let mut j = i + attribute.len();
159            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
160                j += 1;
161            }
162            if j >= bytes.len() || bytes[j] != b'=' {
163                i += 1;
164                continue;
165            }
166            j += 1;
167            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
168                j += 1;
169            }
170            if j >= bytes.len() {
171                break;
172            }
173
174            let quote = bytes[j];
175            if quote == b'"' || quote == b'\'' {
176                let start = j + 1;
177                let mut end = start;
178                while end < bytes.len() && bytes[end] != quote {
179                    end += 1;
180                }
181                if end <= html.len() && start <= end {
182                    let href = html[start..end].trim();
183                    if !href.is_empty() {
184                        values.push(href.to_string());
185                    }
186                }
187                i = end.saturating_add(1);
188                continue;
189            }
190
191            let start = j;
192            let mut end = start;
193            while end < bytes.len() && !bytes[end].is_ascii_whitespace() && bytes[end] != b'>' {
194                end += 1;
195            }
196            if end <= html.len() && start < end {
197                let href = html[start..end].trim();
198                if !href.is_empty() {
199                    values.push(href.to_string());
200                }
201            }
202
203            i = j.saturating_add(1);
204        }
205
206        values
207    }
208
209    fn to_asset_info(url: &str, headers: &header::HeaderMap) -> HttpAssetInfo {
210        HttpAssetInfo {
211            name: Self::file_name_from_url(url),
212            download_url: url.to_string(),
213            size: headers
214                .get(header::CONTENT_LENGTH)
215                .and_then(|v| v.to_str().ok())
216                .and_then(|s| s.parse::<u64>().ok())
217                .unwrap_or(0),
218            last_modified: Self::parse_last_modified(headers.get(header::LAST_MODIFIED)),
219            etag: Self::parse_etag(headers.get(header::ETAG)),
220        }
221    }
222
223    /// Convert discovered links into unique, non-checksum HTTP assets.
224    fn extract_assets_from_html(
225        base: &reqwest::Url,
226        html: &str,
227        page_headers: &header::HeaderMap,
228    ) -> Vec<HttpAssetInfo> {
229        let hrefs = Self::extract_link_values(html);
230        let page_last_modified = Self::parse_last_modified(page_headers.get(header::LAST_MODIFIED));
231        let page_etag = Self::parse_etag(page_headers.get(header::ETAG));
232
233        let mut seen = HashSet::new();
234        let mut assets = Vec::new();
235        for href in hrefs {
236            if href.starts_with('#')
237                || href.starts_with("javascript:")
238                || href.starts_with("mailto:")
239                || href.starts_with("tel:")
240            {
241                continue;
242            }
243
244            let Ok(joined) = base.join(&href) else {
245                continue;
246            };
247            if joined.scheme() != "http" && joined.scheme() != "https" {
248                continue;
249            }
250
251            let joined_str = joined.to_string();
252            let name = Self::file_name_from_url(&joined_str);
253            if name.is_empty() {
254                continue;
255            }
256
257            if parse_filetype(&name) == Filetype::Checksum {
258                continue;
259            }
260
261            if seen.insert(joined_str.clone()) {
262                assets.push(HttpAssetInfo {
263                    download_url: joined_str,
264                    name,
265                    size: 0,
266                    last_modified: page_last_modified,
267                    etag: page_etag.clone(),
268                });
269            }
270        }
271        assets
272    }
273
274    /// Discover downloadable assets from an HTTP endpoint with optional
275    /// `If-Modified-Since` behavior.
276    pub async fn discover_assets_if_modified_since(
277        &self,
278        url_or_slug: &str,
279        last_upgraded: Option<DateTime<Utc>>,
280    ) -> Result<ConditionalDiscoveryResult> {
281        let url = Self::normalize_url(url_or_slug);
282        let response = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
283            .send()
284            .await
285            .context(format!("Failed to send request to {}", url))?;
286
287        if response.status() == StatusCode::NOT_MODIFIED {
288            return Ok(ConditionalDiscoveryResult::NotModified);
289        }
290
291        response
292            .error_for_status_ref()
293            .context(format!("HTTP server returned error for {}", url))?;
294
295        let final_url = response.url().to_string();
296        let content_type = response
297            .headers()
298            .get(header::CONTENT_TYPE)
299            .and_then(|v| v.to_str().ok())
300            .unwrap_or("")
301            .to_lowercase();
302        let response_headers = response.headers().clone();
303
304        if !content_type.contains("text/html") {
305            return Ok(ConditionalDiscoveryResult::Assets(vec![
306                Self::to_asset_info(&final_url, response.headers()),
307            ]));
308        }
309
310        let base = reqwest::Url::parse(&final_url)
311            .context(format!("Failed to parse URL '{}'", final_url))?;
312        let body = response.text().await.context("Failed to read HTML body")?;
313        let assets = Self::extract_assets_from_html(&base, &body, &response_headers);
314
315        if assets.is_empty() {
316            Ok(ConditionalDiscoveryResult::Assets(vec![
317                Self::to_asset_info(&final_url, &response_headers),
318            ]))
319        } else {
320            Ok(ConditionalDiscoveryResult::Assets(assets))
321        }
322    }
323
324    /// Derive a filename from URL path segments with a safe fallback.
325    pub fn file_name_from_url(url: &str) -> String {
326        let without_fragment = url.split('#').next().unwrap_or(url);
327        let without_query = without_fragment
328            .split('?')
329            .next()
330            .unwrap_or(without_fragment);
331        let candidate = without_query.rsplit('/').next().unwrap_or("").trim();
332
333        if candidate.is_empty() {
334            "download.bin".to_string()
335        } else {
336            candidate.to_string()
337        }
338    }
339
340    pub async fn probe_asset(&self, url_or_slug: &str) -> Result<HttpAssetInfo> {
341        match self
342            .probe_asset_if_modified_since(url_or_slug, None)
343            .await?
344        {
345            ConditionalProbeResult::NotModified => {
346                bail!("Unexpected 304 Not Modified response without conditional timestamp")
347            }
348            ConditionalProbeResult::Asset(asset) => Ok(asset),
349        }
350    }
351
352    pub async fn probe_asset_if_modified_since(
353        &self,
354        url_or_slug: &str,
355        last_upgraded: Option<DateTime<Utc>>,
356    ) -> Result<ConditionalProbeResult> {
357        let url = Self::normalize_url(url_or_slug);
358
359        let head_resp = Self::add_if_modified_since(self.client.head(&url), last_upgraded)
360            .send()
361            .await;
362
363        let (size, last_modified, etag) = match head_resp {
364            Ok(resp) if resp.status() == StatusCode::NOT_MODIFIED => {
365                return Ok(ConditionalProbeResult::NotModified);
366            }
367            Ok(resp) if resp.status().is_success() => {
368                let last_modified =
369                    Self::parse_last_modified(resp.headers().get(header::LAST_MODIFIED));
370                let etag = Self::parse_etag(resp.headers().get(header::ETAG));
371                (resp.content_length().unwrap_or(0), last_modified, etag)
372            }
373            Ok(resp)
374                if resp.status() == StatusCode::METHOD_NOT_ALLOWED
375                    || resp.status() == StatusCode::NOT_IMPLEMENTED =>
376            {
377                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
378                    .send()
379                    .await
380                    .context(format!("Failed to send request to {}", url))?;
381
382                if get_resp.status() == StatusCode::NOT_MODIFIED {
383                    return Ok(ConditionalProbeResult::NotModified);
384                }
385
386                get_resp
387                    .error_for_status_ref()
388                    .context(format!("HTTP server returned error for {}", url))?;
389                let last_modified =
390                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
391                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
392                (get_resp.content_length().unwrap_or(0), last_modified, etag)
393            }
394            Ok(resp) => {
395                bail!("HTTP server returned {} for {}", resp.status(), url);
396            }
397            Err(_) => {
398                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
399                    .send()
400                    .await
401                    .context(format!("Failed to send request to {}", url))?;
402
403                if get_resp.status() == StatusCode::NOT_MODIFIED {
404                    return Ok(ConditionalProbeResult::NotModified);
405                }
406
407                get_resp
408                    .error_for_status_ref()
409                    .context(format!("HTTP server returned error for {}", url))?;
410                let last_modified =
411                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
412                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
413                (get_resp.content_length().unwrap_or(0), last_modified, etag)
414            }
415        };
416
417        Ok(ConditionalProbeResult::Asset(HttpAssetInfo {
418            name: Self::file_name_from_url(&url),
419            download_url: url,
420            size,
421            last_modified,
422            etag,
423        }))
424    }
425
426    pub async fn download_file<F>(
427        &self,
428        url: &str,
429        destination: &Path,
430        progress: &mut Option<F>,
431    ) -> Result<()>
432    where
433        F: FnMut(u64, u64),
434    {
435        download_handler::download_file_with_config(
436            &self.client,
437            url,
438            destination,
439            progress,
440            self.download_config,
441        )
442        .await
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::{ConditionalDiscoveryResult, ConditionalProbeResult, HttpClient};
    use chrono::Utc;
    use std::io::{BufRead, BufReader, Write};
    use std::net::TcpListener;
    use std::path::{Path, PathBuf};
    use std::sync::mpsc;
    use std::thread;
    use std::time::{SystemTime, UNIX_EPOCH};
    use std::{fs, io};

    /// Start a throwaway HTTP server on a random local port and return its base URL.
    ///
    /// The server thread handles exactly `max_requests` connections. For each one
    /// it reads the request line and headers, then writes whatever raw response
    /// `handler(method, path)` returns.
    fn spawn_test_server<F>(max_requests: usize, handler: F) -> String
    where
        F: Fn(&str, &str) -> String + Send + 'static,
    {
        let (tx, rx) = mpsc::channel();
        thread::spawn(move || {
            let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
            let addr = listener.local_addr().expect("resolve local addr");
            tx.send(addr).expect("send test server addr");

            for _ in 0..max_requests {
                let (mut stream, _) = listener.accept().expect("accept request");
                let cloned = stream.try_clone().expect("clone stream");
                let mut reader = BufReader::new(cloned);

                let mut request_line = String::new();
                reader
                    .read_line(&mut request_line)
                    .expect("read request line");
                let mut parts = request_line.split_whitespace();
                let method = parts.next().unwrap_or("");
                let path = parts.next().unwrap_or("/");

                // Drain the request headers up to the blank line (or EOF).
                let mut line = String::new();
                loop {
                    line.clear();
                    reader.read_line(&mut line).expect("read request headers");
                    if line == "\r\n" || line.is_empty() {
                        break;
                    }
                }

                let response = handler(method, path);
                stream
                    .write_all(response.as_bytes())
                    .expect("write response");
                stream.flush().expect("flush response");
            }
        });

        let addr = rx.recv().expect("receive server address");
        format!("http://{}", addr)
    }

    /// Assemble a raw HTTP/1.1 response from a status line, headers and body.
    fn http_response(status_line: &str, headers: &[(&str, &str)], body: &str) -> String {
        let mut out = format!("{status_line}\r\n");
        for (k, v) in headers {
            out.push_str(&format!("{k}: {v}\r\n"));
        }
        out.push_str("\r\n");
        out.push_str(body);
        out
    }

    /// Build a path in the system temp directory, made unique with the current
    /// time in nanoseconds.
    fn temp_file_path(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        std::env::temp_dir().join(format!("upstream-http-test-{name}-{nanos}.bin"))
    }

    /// Remove `path` if it exists.
    fn cleanup_file(path: &Path) -> io::Result<()> {
        if path.exists() {
            fs::remove_file(path)?;
        }
        Ok(())
    }

    #[test]
    fn normalize_url_and_file_name_from_url_behave_as_expected() {
        assert_eq!(
            HttpClient::normalize_url("example.com/a"),
            "https://example.com/a"
        );
        assert_eq!(
            HttpClient::normalize_url("http://example.com/a"),
            "http://example.com/a"
        );

        assert_eq!(
            HttpClient::file_name_from_url("https://x.invalid/path/tool.tar.gz?x=1#frag"),
            "tool.tar.gz"
        );
        assert_eq!(
            HttpClient::file_name_from_url("https://x.invalid/path/"),
            "download.bin"
        );
    }

    #[tokio::test]
    async fn discover_assets_extracts_and_filters_html_links() {
        let html = r##"
                <html><body>
                    <a href="tool-v1.2.3-linux.tar.gz">main</a>
                    <a href="/downloads/tool-v1.2.3-linux.tar.gz">duplicate</a>
                    <a href="tool-v1.2.3.sha256">checksum</a>
                    <a href="mailto:test@example.com">mail</a>
                    <a href="#anchor">anchor</a>
                    <a href="https://example.invalid/tool-v1.2.3-macos.zip">mac</a>
                    <button data-download-url="/tool-v1.2.3-windows.zip">win</button>
                </body></html>
            "##;
        let body = html.to_string();
        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
        let server = spawn_test_server(1, move |_, _| {
            http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Content-Type", "text/html"),
                    ("Last-Modified", &last_modified),
                    ("Content-Length", &body.len().to_string()),
                    ("Connection", "close"),
                ],
                &body,
            )
        });
        let client = HttpClient::new().expect("client");

        let result = client
            .discover_assets_if_modified_since(&server, None)
            .await
            .expect("discover assets");

        match result {
            ConditionalDiscoveryResult::NotModified => panic!("unexpected not modified"),
            ConditionalDiscoveryResult::Assets(assets) => {
                // Two linux links (different paths), macos and windows; the
                // checksum, mailto and fragment links are filtered out.
                assert_eq!(assets.len(), 4);
                assert!(
                    assets
                        .iter()
                        .any(|a| a.name.ends_with("tool-v1.2.3-linux.tar.gz"))
                );
                assert!(assets.iter().all(|a| !a.name.ends_with(".sha256")));
                assert!(assets.iter().all(|a| a.last_modified.is_some()));
                assert!(
                    assets
                        .iter()
                        .any(|a| a.name.ends_with("tool-v1.2.3-windows.zip"))
                );
            }
        }
    }

    #[test]
    fn extract_link_values_accepts_spaced_and_unquoted_attributes() {
        let html = r#"
            <a href = "tool-a.zip">quoted with spaces</a>
            <a HREF='tool-b.tar.gz'>uppercase single quoted</a>
            <a href=tool-c.7z>unquoted</a>
            <button data-download-url = /tool-d.zip>data attr</button>
        "#;

        let values = HttpClient::extract_link_values(html);

        assert!(values.contains(&"tool-a.zip".to_string()));
        assert!(values.contains(&"tool-b.tar.gz".to_string()));
        assert!(values.contains(&"tool-c.7z".to_string()));
        assert!(values.contains(&"/tool-d.zip".to_string()));
    }

    #[tokio::test]
    async fn probe_asset_if_modified_since_returns_not_modified_on_304() {
        let server = spawn_test_server(1, move |method, _| {
            assert_eq!(method, "HEAD");
            http_response("HTTP/1.1 304 Not Modified", &[("Connection", "close")], "")
        });
        let client = HttpClient::new().expect("client");

        let result = client
            .probe_asset_if_modified_since(&server, Some(Utc::now()))
            .await
            .expect("probe");
        assert!(matches!(result, ConditionalProbeResult::NotModified));
    }

    #[tokio::test]
    async fn probe_asset_if_modified_since_falls_back_to_get_on_405_head() {
        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
        let etag = "\"abc123\"".to_string();
        let server = spawn_test_server(2, move |method, _| match method {
            "HEAD" => http_response(
                "HTTP/1.1 405 Method Not Allowed",
                &[("Connection", "close"), ("Content-Length", "0")],
                "",
            ),
            "GET" => http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Connection", "close"),
                    ("Content-Length", "11"),
                    ("Last-Modified", &last_modified),
                    ("ETag", &etag),
                ],
                "hello world",
            ),
            _ => http_response(
                "HTTP/1.1 500 Internal Server Error",
                &[("Connection", "close"), ("Content-Length", "0")],
                "",
            ),
        });
        let client = HttpClient::new().expect("client");

        let result = client
            .probe_asset_if_modified_since(&format!("{server}/tool-v2.3.4.tar.gz"), None)
            .await
            .expect("probe fallback");

        match result {
            ConditionalProbeResult::NotModified => panic!("unexpected not modified"),
            ConditionalProbeResult::Asset(asset) => {
                assert_eq!(asset.size, 11);
                assert_eq!(asset.etag.as_deref(), Some("abc123"));
                assert!(asset.last_modified.is_some());
                assert_eq!(asset.name, "tool-v2.3.4.tar.gz");
            }
        }
    }

    #[tokio::test]
    async fn download_file_writes_bytes_and_reports_progress() {
        let body = "stream-body-data".to_string();
        let len = body.len().to_string();
        let body_for_server = body.clone();
        let server = spawn_test_server(1, move |method, _| {
            assert_eq!(method, "GET");
            http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Connection", "close"),
                    ("Content-Type", "application/octet-stream"),
                    ("Content-Length", &len),
                ],
                &body_for_server,
            )
        });
        let client = HttpClient::new().expect("client");
        let output = temp_file_path("download");
        let mut progress = Vec::new();
        let mut cb = Some(|downloaded: u64, total: u64| {
            progress.push((downloaded, total));
        });

        client
            .download_file(&server, &output, &mut cb)
            .await
            .expect("download file");

        assert_eq!(fs::read_to_string(&output).expect("read output file"), body);
        assert!(!progress.is_empty());
        assert_eq!(
            progress.last().copied().expect("final progress"),
            (body.len() as u64, body.len() as u64)
        );

        cleanup_file(&output).expect("cleanup output file");
    }
}