Skip to main content

upstream_rs/providers/http/
http_client.rs

1use anyhow::{Context, Result, bail};
2use chrono::{DateTime, Utc};
3use reqwest::{Client, StatusCode, header};
4use std::collections::HashSet;
5use std::path::Path;
6
7use crate::models::common::enums::Filetype;
8use crate::providers::download_handler;
9use crate::utils::filename_parser::parse_filetype;
10
/// Metadata describing one downloadable asset reachable over HTTP.
#[derive(Debug, Clone)]
pub struct HttpAssetInfo {
    /// URL the asset is fetched from.
    pub download_url: String,
    /// File name derived from the URL via `HttpClient::file_name_from_url`.
    pub name: String,
    /// Size in bytes; `0` when the server did not report one (assets discovered
    /// from an HTML page are always created with `0`).
    pub size: u64,
    /// Parsed `Last-Modified` header, when present and parseable. For assets
    /// discovered from an HTML page this is the page's own header value.
    pub last_modified: Option<DateTime<Utc>>,
    /// `ETag` header value with the surrounding quotes removed, when present.
    pub etag: Option<String>,
}
19
/// Outcome of probing a single asset with an optional `If-Modified-Since` header.
#[derive(Debug, Clone)]
pub enum ConditionalProbeResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// The server answered successfully; carries the probed asset metadata.
    Asset(HttpAssetInfo),
}
25
/// Outcome of asset discovery with an optional `If-Modified-Since` header.
#[derive(Debug, Clone)]
pub enum ConditionalDiscoveryResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// Assets found at the endpoint: either the links extracted from an HTML
    /// page, or a single entry describing the endpoint itself.
    Assets(Vec<HttpAssetInfo>),
}
31
/// HTTP provider client wrapping a `reqwest::Client` configured with this
/// crate's `User-Agent` (see `HttpClient::new`).
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying reqwest client used for every request.
    client: Client,
}
36
37impl HttpClient {
38    fn format_http_date(dt: DateTime<Utc>) -> String {
39        dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
40    }
41
42    fn add_if_modified_since(
43        mut request: reqwest::RequestBuilder,
44        last_upgraded: Option<DateTime<Utc>>,
45    ) -> reqwest::RequestBuilder {
46        if let Some(ts) = last_upgraded {
47            request = request.header(header::IF_MODIFIED_SINCE, Self::format_http_date(ts));
48        }
49        request
50    }
51
52    fn parse_last_modified(value: Option<&header::HeaderValue>) -> Option<DateTime<Utc>> {
53        let raw = value?.to_str().ok()?;
54        DateTime::parse_from_rfc2822(raw)
55            .ok()
56            .map(|dt| dt.with_timezone(&Utc))
57    }
58
59    fn parse_etag(value: Option<&header::HeaderValue>) -> Option<String> {
60        value
61            .and_then(|v| v.to_str().ok())
62            .map(str::trim)
63            .map(|s| s.trim_matches('"').to_string())
64            .filter(|s| !s.is_empty())
65    }
66
67    fn attribute_has_boundary(html: &str, index: usize, attribute: &str) -> bool {
68        let bytes = html.as_bytes();
69        let valid_start = index == 0
70            || bytes
71                .get(index.saturating_sub(1))
72                .map(|b| !b.is_ascii_alphanumeric() && *b != b'-')
73                .unwrap_or(true);
74        let end = index + attribute.len();
75        let valid_end = bytes
76            .get(end)
77            .map(|b| *b == b'=' || b.is_ascii_whitespace())
78            .unwrap_or(false);
79
80        valid_start && valid_end
81    }
82
83    pub fn new() -> Result<Self> {
84        let mut headers = header::HeaderMap::new();
85
86        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
87        headers.insert(
88            header::USER_AGENT,
89            header::HeaderValue::from_str(&user_agent)
90                .context("Failed to create user agent header")?,
91        );
92
93        let client = Client::builder()
94            .default_headers(headers)
95            .build()
96            .context("Failed to build HTTP client")?;
97
98        Ok(Self { client })
99    }
100
101    /// Normalize provider inputs so bare hosts/slugs become HTTPS URLs.
102    pub fn normalize_url(url_or_slug: &str) -> String {
103        let raw = url_or_slug.trim();
104        if raw.starts_with("http://") || raw.starts_with("https://") {
105            raw.to_string()
106        } else {
107            format!("https://{}", raw)
108        }
109    }
110
111    /// Extract likely download URL attribute values from HTML without a full DOM parser.
112    fn extract_link_values(html: &str) -> Vec<String> {
113        let attributes = [
114            "href",
115            "src",
116            "data-href",
117            "data-url",
118            "data-download",
119            "data-download-url",
120        ];
121        let mut values = Vec::new();
122        let lower = html.to_lowercase();
123        let bytes = lower.as_bytes();
124        let mut i = 0_usize;
125
126        while i < bytes.len() {
127            let Some((attribute, attr_offset)) = attributes
128                .iter()
129                .filter_map(|attribute| {
130                    lower[i..]
131                        .find(attribute)
132                        .map(|offset| (*attribute, offset))
133                })
134                .min_by(|(left_attr, left_offset), (right_attr, right_offset)| {
135                    left_offset
136                        .cmp(right_offset)
137                        .then_with(|| right_attr.len().cmp(&left_attr.len()))
138                })
139            else {
140                break;
141            };
142
143            i += attr_offset;
144            if !Self::attribute_has_boundary(&lower, i, attribute) {
145                i += 1;
146                continue;
147            }
148
149            let mut j = i + attribute.len();
150            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
151                j += 1;
152            }
153            if j >= bytes.len() || bytes[j] != b'=' {
154                i += 1;
155                continue;
156            }
157            j += 1;
158            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
159                j += 1;
160            }
161            if j >= bytes.len() {
162                break;
163            }
164
165            let quote = bytes[j];
166            if quote == b'"' || quote == b'\'' {
167                let start = j + 1;
168                let mut end = start;
169                while end < bytes.len() && bytes[end] != quote {
170                    end += 1;
171                }
172                if end <= html.len() && start <= end {
173                    let href = html[start..end].trim();
174                    if !href.is_empty() {
175                        values.push(href.to_string());
176                    }
177                }
178                i = end.saturating_add(1);
179                continue;
180            }
181
182            let start = j;
183            let mut end = start;
184            while end < bytes.len() && !bytes[end].is_ascii_whitespace() && bytes[end] != b'>' {
185                end += 1;
186            }
187            if end <= html.len() && start < end {
188                let href = html[start..end].trim();
189                if !href.is_empty() {
190                    values.push(href.to_string());
191                }
192            }
193
194            i = j.saturating_add(1);
195        }
196
197        values
198    }
199
200    fn to_asset_info(url: &str, headers: &header::HeaderMap) -> HttpAssetInfo {
201        HttpAssetInfo {
202            name: Self::file_name_from_url(url),
203            download_url: url.to_string(),
204            size: headers
205                .get(header::CONTENT_LENGTH)
206                .and_then(|v| v.to_str().ok())
207                .and_then(|s| s.parse::<u64>().ok())
208                .unwrap_or(0),
209            last_modified: Self::parse_last_modified(headers.get(header::LAST_MODIFIED)),
210            etag: Self::parse_etag(headers.get(header::ETAG)),
211        }
212    }
213
214    /// Convert discovered links into unique, non-checksum HTTP assets.
215    fn extract_assets_from_html(
216        base: &reqwest::Url,
217        html: &str,
218        page_headers: &header::HeaderMap,
219    ) -> Vec<HttpAssetInfo> {
220        let hrefs = Self::extract_link_values(html);
221        let page_last_modified = Self::parse_last_modified(page_headers.get(header::LAST_MODIFIED));
222        let page_etag = Self::parse_etag(page_headers.get(header::ETAG));
223
224        let mut seen = HashSet::new();
225        let mut assets = Vec::new();
226        for href in hrefs {
227            if href.starts_with('#')
228                || href.starts_with("javascript:")
229                || href.starts_with("mailto:")
230                || href.starts_with("tel:")
231            {
232                continue;
233            }
234
235            let Ok(joined) = base.join(&href) else {
236                continue;
237            };
238            if joined.scheme() != "http" && joined.scheme() != "https" {
239                continue;
240            }
241
242            let joined_str = joined.to_string();
243            let name = Self::file_name_from_url(&joined_str);
244            if name.is_empty() {
245                continue;
246            }
247
248            if parse_filetype(&name) == Filetype::Checksum {
249                continue;
250            }
251
252            if seen.insert(joined_str.clone()) {
253                assets.push(HttpAssetInfo {
254                    download_url: joined_str,
255                    name,
256                    size: 0,
257                    last_modified: page_last_modified,
258                    etag: page_etag.clone(),
259                });
260            }
261        }
262        assets
263    }
264
265    /// Discover downloadable assets from an HTTP endpoint with optional
266    /// `If-Modified-Since` behavior.
267    pub async fn discover_assets_if_modified_since(
268        &self,
269        url_or_slug: &str,
270        last_upgraded: Option<DateTime<Utc>>,
271    ) -> Result<ConditionalDiscoveryResult> {
272        let url = Self::normalize_url(url_or_slug);
273        let response = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
274            .send()
275            .await
276            .context(format!("Failed to send request to {}", url))?;
277
278        if response.status() == StatusCode::NOT_MODIFIED {
279            return Ok(ConditionalDiscoveryResult::NotModified);
280        }
281
282        response
283            .error_for_status_ref()
284            .context(format!("HTTP server returned error for {}", url))?;
285
286        let final_url = response.url().to_string();
287        let content_type = response
288            .headers()
289            .get(header::CONTENT_TYPE)
290            .and_then(|v| v.to_str().ok())
291            .unwrap_or("")
292            .to_lowercase();
293        let response_headers = response.headers().clone();
294
295        if !content_type.contains("text/html") {
296            return Ok(ConditionalDiscoveryResult::Assets(vec![
297                Self::to_asset_info(&final_url, response.headers()),
298            ]));
299        }
300
301        let base = reqwest::Url::parse(&final_url)
302            .context(format!("Failed to parse URL '{}'", final_url))?;
303        let body = response.text().await.context("Failed to read HTML body")?;
304        let assets = Self::extract_assets_from_html(&base, &body, &response_headers);
305
306        if assets.is_empty() {
307            Ok(ConditionalDiscoveryResult::Assets(vec![
308                Self::to_asset_info(&final_url, &response_headers),
309            ]))
310        } else {
311            Ok(ConditionalDiscoveryResult::Assets(assets))
312        }
313    }
314
315    /// Derive a filename from URL path segments with a safe fallback.
316    pub fn file_name_from_url(url: &str) -> String {
317        let without_fragment = url.split('#').next().unwrap_or(url);
318        let without_query = without_fragment
319            .split('?')
320            .next()
321            .unwrap_or(without_fragment);
322        let candidate = without_query.rsplit('/').next().unwrap_or("").trim();
323
324        if candidate.is_empty() {
325            "download.bin".to_string()
326        } else {
327            candidate.to_string()
328        }
329    }
330
331    pub async fn probe_asset(&self, url_or_slug: &str) -> Result<HttpAssetInfo> {
332        match self
333            .probe_asset_if_modified_since(url_or_slug, None)
334            .await?
335        {
336            ConditionalProbeResult::NotModified => {
337                bail!("Unexpected 304 Not Modified response without conditional timestamp")
338            }
339            ConditionalProbeResult::Asset(asset) => Ok(asset),
340        }
341    }
342
343    pub async fn probe_asset_if_modified_since(
344        &self,
345        url_or_slug: &str,
346        last_upgraded: Option<DateTime<Utc>>,
347    ) -> Result<ConditionalProbeResult> {
348        let url = Self::normalize_url(url_or_slug);
349
350        let head_resp = Self::add_if_modified_since(self.client.head(&url), last_upgraded)
351            .send()
352            .await;
353
354        let (size, last_modified, etag) = match head_resp {
355            Ok(resp) if resp.status() == StatusCode::NOT_MODIFIED => {
356                return Ok(ConditionalProbeResult::NotModified);
357            }
358            Ok(resp) if resp.status().is_success() => {
359                let last_modified =
360                    Self::parse_last_modified(resp.headers().get(header::LAST_MODIFIED));
361                let etag = Self::parse_etag(resp.headers().get(header::ETAG));
362                (resp.content_length().unwrap_or(0), last_modified, etag)
363            }
364            Ok(resp)
365                if resp.status() == StatusCode::METHOD_NOT_ALLOWED
366                    || resp.status() == StatusCode::NOT_IMPLEMENTED =>
367            {
368                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
369                    .send()
370                    .await
371                    .context(format!("Failed to send request to {}", url))?;
372
373                if get_resp.status() == StatusCode::NOT_MODIFIED {
374                    return Ok(ConditionalProbeResult::NotModified);
375                }
376
377                get_resp
378                    .error_for_status_ref()
379                    .context(format!("HTTP server returned error for {}", url))?;
380                let last_modified =
381                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
382                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
383                (get_resp.content_length().unwrap_or(0), last_modified, etag)
384            }
385            Ok(resp) => {
386                bail!("HTTP server returned {} for {}", resp.status(), url);
387            }
388            Err(_) => {
389                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
390                    .send()
391                    .await
392                    .context(format!("Failed to send request to {}", url))?;
393
394                if get_resp.status() == StatusCode::NOT_MODIFIED {
395                    return Ok(ConditionalProbeResult::NotModified);
396                }
397
398                get_resp
399                    .error_for_status_ref()
400                    .context(format!("HTTP server returned error for {}", url))?;
401                let last_modified =
402                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
403                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
404                (get_resp.content_length().unwrap_or(0), last_modified, etag)
405            }
406        };
407
408        Ok(ConditionalProbeResult::Asset(HttpAssetInfo {
409            name: Self::file_name_from_url(&url),
410            download_url: url,
411            size,
412            last_modified,
413            etag,
414        }))
415    }
416
417    pub async fn download_file<F>(
418        &self,
419        url: &str,
420        destination: &Path,
421        progress: &mut Option<F>,
422    ) -> Result<()>
423    where
424        F: FnMut(u64, u64),
425    {
426        download_handler::download_file(&self.client, url, destination, progress).await
427    }
428}
429
430#[cfg(test)]
431mod tests {
432    use super::{ConditionalDiscoveryResult, ConditionalProbeResult, HttpClient};
433    use chrono::Utc;
434    use std::io::{BufRead, BufReader, Write};
435    use std::net::TcpListener;
436    use std::path::{Path, PathBuf};
437    use std::sync::mpsc;
438    use std::thread;
439    use std::time::{SystemTime, UNIX_EPOCH};
440    use std::{fs, io};
441
442    fn spawn_test_server<F>(max_requests: usize, handler: F) -> String
443    where
444        F: Fn(&str, &str) -> String + Send + 'static,
445    {
446        let (tx, rx) = mpsc::channel();
447        thread::spawn(move || {
448            let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
449            let addr = listener.local_addr().expect("resolve local addr");
450            tx.send(addr).expect("send test server addr");
451
452            for _ in 0..max_requests {
453                let (mut stream, _) = listener.accept().expect("accept request");
454                let cloned = stream.try_clone().expect("clone stream");
455                let mut reader = BufReader::new(cloned);
456
457                let mut request_line = String::new();
458                reader
459                    .read_line(&mut request_line)
460                    .expect("read request line");
461                let mut parts = request_line.split_whitespace();
462                let method = parts.next().unwrap_or("");
463                let path = parts.next().unwrap_or("/");
464
465                let mut line = String::new();
466                loop {
467                    line.clear();
468                    reader.read_line(&mut line).expect("read request headers");
469                    if line == "\r\n" || line.is_empty() {
470                        break;
471                    }
472                }
473
474                let response = handler(method, path);
475                stream
476                    .write_all(response.as_bytes())
477                    .expect("write response");
478                stream.flush().expect("flush response");
479            }
480        });
481
482        let addr = rx.recv().expect("receive server address");
483        format!("http://{}", addr)
484    }
485
486    fn http_response(status_line: &str, headers: &[(&str, &str)], body: &str) -> String {
487        let mut out = format!("{status_line}\r\n");
488        for (k, v) in headers {
489            out.push_str(&format!("{k}: {v}\r\n"));
490        }
491        out.push_str("\r\n");
492        out.push_str(body);
493        out
494    }
495
496    fn temp_file_path(name: &str) -> PathBuf {
497        let nanos = SystemTime::now()
498            .duration_since(UNIX_EPOCH)
499            .map(|d| d.as_nanos())
500            .unwrap_or(0);
501        std::env::temp_dir().join(format!("upstream-http-test-{name}-{nanos}.bin"))
502    }
503
504    fn cleanup_file(path: &Path) -> io::Result<()> {
505        if path.exists() {
506            fs::remove_file(path)?;
507        }
508        Ok(())
509    }
510
511    #[test]
512    fn normalize_url_and_file_name_from_url_behave_as_expected() {
513        assert_eq!(
514            HttpClient::normalize_url("example.com/a"),
515            "https://example.com/a"
516        );
517        assert_eq!(
518            HttpClient::normalize_url("http://example.com/a"),
519            "http://example.com/a"
520        );
521
522        assert_eq!(
523            HttpClient::file_name_from_url("https://x.invalid/path/tool.tar.gz?x=1#frag"),
524            "tool.tar.gz"
525        );
526        assert_eq!(
527            HttpClient::file_name_from_url("https://x.invalid/path/"),
528            "download.bin"
529        );
530    }
531
532    #[tokio::test]
533    async fn discover_assets_extracts_and_filters_html_links() {
534        let html = r##"
535                <html><body>
536                    <a href="tool-v1.2.3-linux.tar.gz">main</a>
537                    <a href="/downloads/tool-v1.2.3-linux.tar.gz">duplicate</a>
538                    <a href="tool-v1.2.3.sha256">checksum</a>
539                    <a href="mailto:test@example.com">mail</a>
540                    <a href="#anchor">anchor</a>
541                    <a href="https://example.invalid/tool-v1.2.3-macos.zip">mac</a>
542                    <button data-download-url="/tool-v1.2.3-windows.zip">win</button>
543                </body></html>
544            "##;
545        let body = html.to_string();
546        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
547        let server = spawn_test_server(1, move |_, _| {
548            http_response(
549                "HTTP/1.1 200 OK",
550                &[
551                    ("Content-Type", "text/html"),
552                    ("Last-Modified", &last_modified),
553                    ("Content-Length", &body.len().to_string()),
554                    ("Connection", "close"),
555                ],
556                &body,
557            )
558        });
559        let client = HttpClient::new().expect("client");
560
561        let result = client
562            .discover_assets_if_modified_since(&server, None)
563            .await
564            .expect("discover assets");
565
566        match result {
567            ConditionalDiscoveryResult::NotModified => panic!("unexpected not modified"),
568            ConditionalDiscoveryResult::Assets(assets) => {
569                assert_eq!(assets.len(), 4);
570                assert!(
571                    assets
572                        .iter()
573                        .any(|a| a.name.ends_with("tool-v1.2.3-linux.tar.gz"))
574                );
575                assert!(assets.iter().all(|a| !a.name.ends_with(".sha256")));
576                assert!(assets.iter().all(|a| a.last_modified.is_some()));
577                assert!(
578                    assets
579                        .iter()
580                        .any(|a| a.name.ends_with("tool-v1.2.3-windows.zip"))
581                );
582            }
583        }
584    }
585
586    #[test]
587    fn extract_link_values_accepts_spaced_and_unquoted_attributes() {
588        let html = r#"
589            <a href = "tool-a.zip">quoted with spaces</a>
590            <a HREF='tool-b.tar.gz'>uppercase single quoted</a>
591            <a href=tool-c.7z>unquoted</a>
592            <button data-download-url = /tool-d.zip>data attr</button>
593        "#;
594
595        let values = HttpClient::extract_link_values(html);
596
597        assert!(values.contains(&"tool-a.zip".to_string()));
598        assert!(values.contains(&"tool-b.tar.gz".to_string()));
599        assert!(values.contains(&"tool-c.7z".to_string()));
600        assert!(values.contains(&"/tool-d.zip".to_string()));
601    }
602
603    #[tokio::test]
604    async fn probe_asset_if_modified_since_returns_not_modified_on_304() {
605        let server = spawn_test_server(1, move |method, _| {
606            assert_eq!(method, "HEAD");
607            http_response("HTTP/1.1 304 Not Modified", &[("Connection", "close")], "")
608        });
609        let client = HttpClient::new().expect("client");
610
611        let result = client
612            .probe_asset_if_modified_since(&server, Some(Utc::now()))
613            .await
614            .expect("probe");
615        assert!(matches!(result, ConditionalProbeResult::NotModified));
616    }
617
618    #[tokio::test]
619    async fn probe_asset_if_modified_since_falls_back_to_get_on_405_head() {
620        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
621        let etag = "\"abc123\"".to_string();
622        let server = spawn_test_server(2, move |method, _| match method {
623            "HEAD" => http_response(
624                "HTTP/1.1 405 Method Not Allowed",
625                &[("Connection", "close"), ("Content-Length", "0")],
626                "",
627            ),
628            "GET" => http_response(
629                "HTTP/1.1 200 OK",
630                &[
631                    ("Connection", "close"),
632                    ("Content-Length", "11"),
633                    ("Last-Modified", &last_modified),
634                    ("ETag", &etag),
635                ],
636                "hello world",
637            ),
638            _ => http_response(
639                "HTTP/1.1 500 Internal Server Error",
640                &[("Connection", "close"), ("Content-Length", "0")],
641                "",
642            ),
643        });
644        let client = HttpClient::new().expect("client");
645
646        let result = client
647            .probe_asset_if_modified_since(&format!("{server}/tool-v2.3.4.tar.gz"), None)
648            .await
649            .expect("probe fallback");
650
651        match result {
652            ConditionalProbeResult::NotModified => panic!("unexpected not modified"),
653            ConditionalProbeResult::Asset(asset) => {
654                assert_eq!(asset.size, 11);
655                assert_eq!(asset.etag.as_deref(), Some("abc123"));
656                assert!(asset.last_modified.is_some());
657                assert_eq!(asset.name, "tool-v2.3.4.tar.gz");
658            }
659        }
660    }
661
662    #[tokio::test]
663    async fn download_file_writes_bytes_and_reports_progress() {
664        let body = "stream-body-data".to_string();
665        let len = body.len().to_string();
666        let body_for_server = body.clone();
667        let server = spawn_test_server(1, move |method, _| {
668            assert_eq!(method, "GET");
669            http_response(
670                "HTTP/1.1 200 OK",
671                &[
672                    ("Connection", "close"),
673                    ("Content-Type", "application/octet-stream"),
674                    ("Content-Length", &len),
675                ],
676                &body_for_server,
677            )
678        });
679        let client = HttpClient::new().expect("client");
680        let output = temp_file_path("download");
681        let mut progress = Vec::new();
682        let mut cb = Some(|downloaded: u64, total: u64| {
683            progress.push((downloaded, total));
684        });
685
686        client
687            .download_file(&server, &output, &mut cb)
688            .await
689            .expect("download file");
690
691        assert_eq!(fs::read_to_string(&output).expect("read output file"), body);
692        assert!(!progress.is_empty());
693        assert_eq!(
694            progress.last().copied().expect("final progress"),
695            (body.len() as u64, body.len() as u64)
696        );
697
698        cleanup_file(&output).expect("cleanup output file");
699    }
700}