// upstream_rs/providers/http/http_client.rs

1use anyhow::{Context, Result, bail};
2use chrono::{DateTime, Utc};
3use reqwest::{Client, StatusCode, header};
4use std::collections::HashSet;
5use std::path::Path;
6
7use crate::models::common::enums::Filetype;
8use crate::models::upstream::DownloadConfig;
9use crate::providers::download_handler;
10use crate::utils::filename_parser::parse_filetype;
11
/// Metadata describing one downloadable asset reachable over HTTP(S).
#[derive(Debug, Clone)]
pub struct HttpAssetInfo {
    /// URL the asset can be downloaded from.
    pub download_url: String,
    /// File name derived from the URL via `HttpClient::file_name_from_url`.
    pub name: String,
    /// Size in bytes; `0` when the size is not known.
    pub size: u64,
    /// Parsed `Last-Modified` header value, when present and parseable.
    /// Assets discovered from an HTML page carry the page's value.
    pub last_modified: Option<DateTime<Utc>>,
    /// `ETag` header value with surrounding whitespace and double quotes trimmed.
    /// Assets discovered from an HTML page carry the page's value.
    pub etag: Option<String>,
}
20
/// Outcome of probing a single URL with an optional `If-Modified-Since` condition.
#[derive(Debug, Clone)]
pub enum ConditionalProbeResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// The server answered successfully; carries the metadata read from the response.
    Asset(HttpAssetInfo),
}
26
/// Outcome of discovering assets at a URL with an optional `If-Modified-Since` condition.
#[derive(Debug, Clone)]
pub enum ConditionalDiscoveryResult {
    /// The server answered `304 Not Modified`.
    NotModified,
    /// Assets found at the endpoint: either links extracted from an HTML page,
    /// or a single entry describing the endpoint itself.
    Assets(Vec<HttpAssetInfo>),
}
32
/// HTTP provider client used to discover, probe and download assets.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying `reqwest` client, built with a default `User-Agent` header.
    client: Client,
    /// Download settings forwarded to `download_handler::download_file`.
    download_config: DownloadConfig,
}
38
39impl HttpClient {
40    fn format_http_date(dt: DateTime<Utc>) -> String {
41        dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
42    }
43
44    fn add_if_modified_since(
45        mut request: reqwest::RequestBuilder,
46        last_upgraded: Option<DateTime<Utc>>,
47    ) -> reqwest::RequestBuilder {
48        if let Some(ts) = last_upgraded {
49            request = request.header(header::IF_MODIFIED_SINCE, Self::format_http_date(ts));
50        }
51        request
52    }
53
54    fn parse_last_modified(value: Option<&header::HeaderValue>) -> Option<DateTime<Utc>> {
55        let raw = value?.to_str().ok()?;
56        DateTime::parse_from_rfc2822(raw)
57            .ok()
58            .map(|dt| dt.with_timezone(&Utc))
59    }
60
61    fn parse_etag(value: Option<&header::HeaderValue>) -> Option<String> {
62        value
63            .and_then(|v| v.to_str().ok())
64            .map(str::trim)
65            .map(|s| s.trim_matches('"').to_string())
66            .filter(|s| !s.is_empty())
67    }
68
69    fn attribute_has_boundary(html: &str, index: usize, attribute: &str) -> bool {
70        let bytes = html.as_bytes();
71        let valid_start = index == 0
72            || bytes
73                .get(index.saturating_sub(1))
74                .map(|b| !b.is_ascii_alphanumeric() && *b != b'-')
75                .unwrap_or(true);
76        let end = index + attribute.len();
77        let valid_end = bytes
78            .get(end)
79            .map(|b| *b == b'=' || b.is_ascii_whitespace())
80            .unwrap_or(false);
81
82        valid_start && valid_end
83    }
84
85    pub fn new(download_config: DownloadConfig) -> Result<Self> {
86        let mut headers = header::HeaderMap::new();
87
88        let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
89        headers.insert(
90            header::USER_AGENT,
91            header::HeaderValue::from_str(&user_agent)
92                .context("Failed to create user agent header")?,
93        );
94
95        let client = Client::builder()
96            .default_headers(headers)
97            .build()
98            .context("Failed to build HTTP client")?;
99
100        Ok(Self {
101            client,
102            download_config,
103        })
104    }
105
106    /// Normalize provider inputs so bare hosts/slugs become HTTPS URLs.
107    pub fn normalize_url(url_or_slug: &str) -> String {
108        let raw = url_or_slug.trim();
109        if raw.starts_with("http://") || raw.starts_with("https://") {
110            raw.to_string()
111        } else {
112            format!("https://{}", raw)
113        }
114    }
115
116    /// Extract likely download URL attribute values from HTML without a full DOM parser.
117    fn extract_link_values(html: &str) -> Vec<String> {
118        let attributes = [
119            "href",
120            "src",
121            "data-href",
122            "data-url",
123            "data-download",
124            "data-download-url",
125        ];
126        let mut values = Vec::new();
127        let lower = html.to_lowercase();
128        let bytes = lower.as_bytes();
129        let mut i = 0_usize;
130
131        while i < bytes.len() {
132            let Some((attribute, attr_offset)) = attributes
133                .iter()
134                .filter_map(|attribute| {
135                    lower[i..]
136                        .find(attribute)
137                        .map(|offset| (*attribute, offset))
138                })
139                .min_by(|(left_attr, left_offset), (right_attr, right_offset)| {
140                    left_offset
141                        .cmp(right_offset)
142                        .then_with(|| right_attr.len().cmp(&left_attr.len()))
143                })
144            else {
145                break;
146            };
147
148            i += attr_offset;
149            if !Self::attribute_has_boundary(&lower, i, attribute) {
150                i += 1;
151                continue;
152            }
153
154            let mut j = i + attribute.len();
155            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
156                j += 1;
157            }
158            if j >= bytes.len() || bytes[j] != b'=' {
159                i += 1;
160                continue;
161            }
162            j += 1;
163            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
164                j += 1;
165            }
166            if j >= bytes.len() {
167                break;
168            }
169
170            let quote = bytes[j];
171            if quote == b'"' || quote == b'\'' {
172                let start = j + 1;
173                let mut end = start;
174                while end < bytes.len() && bytes[end] != quote {
175                    end += 1;
176                }
177                if end <= html.len() && start <= end {
178                    let href = html[start..end].trim();
179                    if !href.is_empty() {
180                        values.push(href.to_string());
181                    }
182                }
183                i = end.saturating_add(1);
184                continue;
185            }
186
187            let start = j;
188            let mut end = start;
189            while end < bytes.len() && !bytes[end].is_ascii_whitespace() && bytes[end] != b'>' {
190                end += 1;
191            }
192            if end <= html.len() && start < end {
193                let href = html[start..end].trim();
194                if !href.is_empty() {
195                    values.push(href.to_string());
196                }
197            }
198
199            i = j.saturating_add(1);
200        }
201
202        values
203    }
204
205    fn to_asset_info(url: &str, headers: &header::HeaderMap) -> HttpAssetInfo {
206        HttpAssetInfo {
207            name: Self::file_name_from_url(url),
208            download_url: url.to_string(),
209            size: headers
210                .get(header::CONTENT_LENGTH)
211                .and_then(|v| v.to_str().ok())
212                .and_then(|s| s.parse::<u64>().ok())
213                .unwrap_or(0),
214            last_modified: Self::parse_last_modified(headers.get(header::LAST_MODIFIED)),
215            etag: Self::parse_etag(headers.get(header::ETAG)),
216        }
217    }
218
219    /// Convert discovered links into unique, non-checksum HTTP assets.
220    fn extract_assets_from_html(
221        base: &reqwest::Url,
222        html: &str,
223        page_headers: &header::HeaderMap,
224    ) -> Vec<HttpAssetInfo> {
225        let hrefs = Self::extract_link_values(html);
226        let page_last_modified = Self::parse_last_modified(page_headers.get(header::LAST_MODIFIED));
227        let page_etag = Self::parse_etag(page_headers.get(header::ETAG));
228
229        let mut seen = HashSet::new();
230        let mut assets = Vec::new();
231        for href in hrefs {
232            if href.starts_with('#')
233                || href.starts_with("javascript:")
234                || href.starts_with("mailto:")
235                || href.starts_with("tel:")
236            {
237                continue;
238            }
239
240            let Ok(joined) = base.join(&href) else {
241                continue;
242            };
243            if joined.scheme() != "http" && joined.scheme() != "https" {
244                continue;
245            }
246
247            let joined_str = joined.to_string();
248            let name = Self::file_name_from_url(&joined_str);
249            if name.is_empty() {
250                continue;
251            }
252
253            if parse_filetype(&name) == Filetype::Checksum {
254                continue;
255            }
256
257            if seen.insert(joined_str.clone()) {
258                assets.push(HttpAssetInfo {
259                    download_url: joined_str,
260                    name,
261                    size: 0,
262                    last_modified: page_last_modified,
263                    etag: page_etag.clone(),
264                });
265            }
266        }
267        assets
268    }
269
270    /// Discover downloadable assets from an HTTP endpoint with optional
271    /// `If-Modified-Since` behavior.
272    pub async fn discover_assets_if_modified_since(
273        &self,
274        url_or_slug: &str,
275        last_upgraded: Option<DateTime<Utc>>,
276    ) -> Result<ConditionalDiscoveryResult> {
277        let url = Self::normalize_url(url_or_slug);
278        let response = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
279            .send()
280            .await
281            .context(format!("Failed to send request to {}", url))?;
282
283        if response.status() == StatusCode::NOT_MODIFIED {
284            return Ok(ConditionalDiscoveryResult::NotModified);
285        }
286
287        response
288            .error_for_status_ref()
289            .context(format!("HTTP server returned error for {}", url))?;
290
291        let final_url = response.url().to_string();
292        let content_type = response
293            .headers()
294            .get(header::CONTENT_TYPE)
295            .and_then(|v| v.to_str().ok())
296            .unwrap_or("")
297            .to_lowercase();
298        let response_headers = response.headers().clone();
299
300        if !content_type.contains("text/html") {
301            return Ok(ConditionalDiscoveryResult::Assets(vec![
302                Self::to_asset_info(&final_url, response.headers()),
303            ]));
304        }
305
306        let base = reqwest::Url::parse(&final_url)
307            .context(format!("Failed to parse URL '{}'", final_url))?;
308        let body = response.text().await.context("Failed to read HTML body")?;
309        let assets = Self::extract_assets_from_html(&base, &body, &response_headers);
310
311        if assets.is_empty() {
312            Ok(ConditionalDiscoveryResult::Assets(vec![
313                Self::to_asset_info(&final_url, &response_headers),
314            ]))
315        } else {
316            Ok(ConditionalDiscoveryResult::Assets(assets))
317        }
318    }
319
320    /// Derive a filename from URL path segments with a safe fallback.
321    pub fn file_name_from_url(url: &str) -> String {
322        let without_fragment = url.split('#').next().unwrap_or(url);
323        let without_query = without_fragment
324            .split('?')
325            .next()
326            .unwrap_or(without_fragment);
327        let candidate = without_query.rsplit('/').next().unwrap_or("").trim();
328
329        if candidate.is_empty() {
330            "download.bin".to_string()
331        } else {
332            candidate.to_string()
333        }
334    }
335
336    pub async fn probe_asset(&self, url_or_slug: &str) -> Result<HttpAssetInfo> {
337        match self
338            .probe_asset_if_modified_since(url_or_slug, None)
339            .await?
340        {
341            ConditionalProbeResult::NotModified => {
342                bail!("Unexpected 304 Not Modified response without conditional timestamp")
343            }
344            ConditionalProbeResult::Asset(asset) => Ok(asset),
345        }
346    }
347
348    pub async fn probe_asset_if_modified_since(
349        &self,
350        url_or_slug: &str,
351        last_upgraded: Option<DateTime<Utc>>,
352    ) -> Result<ConditionalProbeResult> {
353        let url = Self::normalize_url(url_or_slug);
354
355        let head_resp = Self::add_if_modified_since(self.client.head(&url), last_upgraded)
356            .send()
357            .await;
358
359        let (size, last_modified, etag) = match head_resp {
360            Ok(resp) if resp.status() == StatusCode::NOT_MODIFIED => {
361                return Ok(ConditionalProbeResult::NotModified);
362            }
363            Ok(resp) if resp.status().is_success() => {
364                let last_modified =
365                    Self::parse_last_modified(resp.headers().get(header::LAST_MODIFIED));
366                let etag = Self::parse_etag(resp.headers().get(header::ETAG));
367                (resp.content_length().unwrap_or(0), last_modified, etag)
368            }
369            Ok(resp)
370                if resp.status() == StatusCode::METHOD_NOT_ALLOWED
371                    || resp.status() == StatusCode::NOT_IMPLEMENTED =>
372            {
373                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
374                    .send()
375                    .await
376                    .context(format!("Failed to send request to {}", url))?;
377
378                if get_resp.status() == StatusCode::NOT_MODIFIED {
379                    return Ok(ConditionalProbeResult::NotModified);
380                }
381
382                get_resp
383                    .error_for_status_ref()
384                    .context(format!("HTTP server returned error for {}", url))?;
385                let last_modified =
386                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
387                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
388                (get_resp.content_length().unwrap_or(0), last_modified, etag)
389            }
390            Ok(resp) => {
391                bail!("HTTP server returned {} for {}", resp.status(), url);
392            }
393            Err(_) => {
394                let get_resp = Self::add_if_modified_since(self.client.get(&url), last_upgraded)
395                    .send()
396                    .await
397                    .context(format!("Failed to send request to {}", url))?;
398
399                if get_resp.status() == StatusCode::NOT_MODIFIED {
400                    return Ok(ConditionalProbeResult::NotModified);
401                }
402
403                get_resp
404                    .error_for_status_ref()
405                    .context(format!("HTTP server returned error for {}", url))?;
406                let last_modified =
407                    Self::parse_last_modified(get_resp.headers().get(header::LAST_MODIFIED));
408                let etag = Self::parse_etag(get_resp.headers().get(header::ETAG));
409                (get_resp.content_length().unwrap_or(0), last_modified, etag)
410            }
411        };
412
413        Ok(ConditionalProbeResult::Asset(HttpAssetInfo {
414            name: Self::file_name_from_url(&url),
415            download_url: url,
416            size,
417            last_modified,
418            etag,
419        }))
420    }
421
422    pub async fn download_file<F>(
423        &self,
424        url: &str,
425        destination: &Path,
426        progress: &mut Option<F>,
427    ) -> Result<()>
428    where
429        F: FnMut(u64, u64),
430    {
431        download_handler::download_file(
432            &self.client,
433            url,
434            destination,
435            progress,
436            self.download_config,
437        )
438        .await
439    }
440}
441
442#[cfg(test)]
443mod tests {
444    use super::{ConditionalDiscoveryResult, ConditionalProbeResult, HttpClient};
445    use chrono::Utc;
446    use std::io::{BufRead, BufReader, Write};
447    use std::net::TcpListener;
448    use std::path::{Path, PathBuf};
449    use std::sync::mpsc;
450    use std::thread;
451    use std::time::{SystemTime, UNIX_EPOCH};
452    use std::{fs, io};
453
    /// Start a minimal HTTP server on a background thread and return its base URL
    /// (`http://<addr>`).
    ///
    /// The server binds to an OS-assigned port on `127.0.0.1`, accepts exactly
    /// `max_requests` connections one after another, and for each one calls
    /// `handler(method, path)` and writes the returned string back verbatim as the
    /// raw response. Request bodies are not read.
    fn spawn_test_server<F>(max_requests: usize, handler: F) -> String
    where
        F: Fn(&str, &str) -> String + Send + 'static,
    {
        // The bound address is reported back through this channel.
        let (tx, rx) = mpsc::channel();
        thread::spawn(move || {
            let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
            let addr = listener.local_addr().expect("resolve local addr");
            tx.send(addr).expect("send test server addr");

            for _ in 0..max_requests {
                let (mut stream, _) = listener.accept().expect("accept request");
                // Read through a cloned handle so `stream` stays free for writing.
                let cloned = stream.try_clone().expect("clone stream");
                let mut reader = BufReader::new(cloned);

                // Request line: "<METHOD> <PATH> <VERSION>".
                let mut request_line = String::new();
                reader
                    .read_line(&mut request_line)
                    .expect("read request line");
                let mut parts = request_line.split_whitespace();
                let method = parts.next().unwrap_or("");
                let path = parts.next().unwrap_or("/");

                // Drain the request headers up to the blank line (or EOF).
                let mut line = String::new();
                loop {
                    line.clear();
                    reader.read_line(&mut line).expect("read request headers");
                    if line == "\r\n" || line.is_empty() {
                        break;
                    }
                }

                let response = handler(method, path);
                stream
                    .write_all(response.as_bytes())
                    .expect("write response");
                stream.flush().expect("flush response");
            }
        });

        let addr = rx.recv().expect("receive server address");
        format!("http://{}", addr)
    }
497
498    fn http_response(status_line: &str, headers: &[(&str, &str)], body: &str) -> String {
499        let mut out = format!("{status_line}\r\n");
500        for (k, v) in headers {
501            out.push_str(&format!("{k}: {v}\r\n"));
502        }
503        out.push_str("\r\n");
504        out.push_str(body);
505        out
506    }
507
508    fn temp_file_path(name: &str) -> PathBuf {
509        let nanos = SystemTime::now()
510            .duration_since(UNIX_EPOCH)
511            .map(|d| d.as_nanos())
512            .unwrap_or(0);
513        std::env::temp_dir().join(format!("upstream-http-test-{name}-{nanos}.bin"))
514    }
515
516    fn cleanup_file(path: &Path) -> io::Result<()> {
517        if path.exists() {
518            fs::remove_file(path)?;
519        }
520        Ok(())
521    }
522
523    #[test]
524    fn normalize_url_and_file_name_from_url_behave_as_expected() {
525        assert_eq!(
526            HttpClient::normalize_url("example.com/a"),
527            "https://example.com/a"
528        );
529        assert_eq!(
530            HttpClient::normalize_url("http://example.com/a"),
531            "http://example.com/a"
532        );
533
534        assert_eq!(
535            HttpClient::file_name_from_url("https://x.invalid/path/tool.tar.gz?x=1#frag"),
536            "tool.tar.gz"
537        );
538        assert_eq!(
539            HttpClient::file_name_from_url("https://x.invalid/path/"),
540            "download.bin"
541        );
542    }
543
    /// Serves one HTML page and checks that discovery resolves its links, drops the
    /// checksum, `mailto:` and fragment links, picks up the `data-download-url`
    /// attribute, and stamps every asset with the page's `Last-Modified` value.
    #[tokio::test]
    async fn discover_assets_extracts_and_filters_html_links() {
        let html = r##"
                <html><body>
                    <a href="tool-v1.2.3-linux.tar.gz">main</a>
                    <a href="/downloads/tool-v1.2.3-linux.tar.gz">duplicate</a>
                    <a href="tool-v1.2.3.sha256">checksum</a>
                    <a href="mailto:test@example.com">mail</a>
                    <a href="#anchor">anchor</a>
                    <a href="https://example.invalid/tool-v1.2.3-macos.zip">mac</a>
                    <button data-download-url="/tool-v1.2.3-windows.zip">win</button>
                </body></html>
            "##;
        let body = html.to_string();
        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
        let server = spawn_test_server(1, move |_, _| {
            http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Content-Type", "text/html"),
                    ("Last-Modified", &last_modified),
                    ("Content-Length", &body.len().to_string()),
                    ("Connection", "close"),
                ],
                &body,
            )
        });
        let client = HttpClient::new(Default::default()).expect("client");

        let result = client
            .discover_assets_if_modified_since(&server, None)
            .await
            .expect("discover assets");

        match result {
            ConditionalDiscoveryResult::NotModified => panic!("unexpected not modified"),
            ConditionalDiscoveryResult::Assets(assets) => {
                // Two distinct linux URLs (relative and /downloads/), macOS and Windows.
                assert_eq!(assets.len(), 4);
                assert!(
                    assets
                        .iter()
                        .any(|a| a.name.ends_with("tool-v1.2.3-linux.tar.gz"))
                );
                assert!(assets.iter().all(|a| !a.name.ends_with(".sha256")));
                assert!(assets.iter().all(|a| a.last_modified.is_some()));
                assert!(
                    assets
                        .iter()
                        .any(|a| a.name.ends_with("tool-v1.2.3-windows.zip"))
                );
            }
        }
    }
597
598    #[test]
599    fn extract_link_values_accepts_spaced_and_unquoted_attributes() {
600        let html = r#"
601            <a href = "tool-a.zip">quoted with spaces</a>
602            <a HREF='tool-b.tar.gz'>uppercase single quoted</a>
603            <a href=tool-c.7z>unquoted</a>
604            <button data-download-url = /tool-d.zip>data attr</button>
605        "#;
606
607        let values = HttpClient::extract_link_values(html);
608
609        assert!(values.contains(&"tool-a.zip".to_string()));
610        assert!(values.contains(&"tool-b.tar.gz".to_string()));
611        assert!(values.contains(&"tool-c.7z".to_string()));
612        assert!(values.contains(&"/tool-d.zip".to_string()));
613    }
614
    /// A `304` answer to the conditional `HEAD` request must surface as
    /// `NotModified` without any further request (the server accepts only one).
    #[tokio::test]
    async fn probe_asset_if_modified_since_returns_not_modified_on_304() {
        let server = spawn_test_server(1, move |method, _| {
            assert_eq!(method, "HEAD");
            http_response("HTTP/1.1 304 Not Modified", &[("Connection", "close")], "")
        });
        let client = HttpClient::new(Default::default()).expect("client");

        let result = client
            .probe_asset_if_modified_since(&server, Some(Utc::now()))
            .await
            .expect("probe");
        assert!(matches!(result, ConditionalProbeResult::NotModified));
    }
629
    /// When `HEAD` is rejected with `405`, the probe must retry with `GET` and read
    /// size, `ETag` and `Last-Modified` from that response.
    #[tokio::test]
    async fn probe_asset_if_modified_since_falls_back_to_get_on_405_head() {
        let last_modified = "Tue, 10 Feb 2026 15:04:05 GMT".to_string();
        let etag = "\"abc123\"".to_string();
        // Two requests are expected: the rejected HEAD, then the GET fallback.
        let server = spawn_test_server(2, move |method, _| match method {
            "HEAD" => http_response(
                "HTTP/1.1 405 Method Not Allowed",
                &[("Connection", "close"), ("Content-Length", "0")],
                "",
            ),
            "GET" => http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Connection", "close"),
                    ("Content-Length", "11"),
                    ("Last-Modified", &last_modified),
                    ("ETag", &etag),
                ],
                "hello world",
            ),
            _ => http_response(
                "HTTP/1.1 500 Internal Server Error",
                &[("Connection", "close"), ("Content-Length", "0")],
                "",
            ),
        });
        let client = HttpClient::new(Default::default()).expect("client");

        let result = client
            .probe_asset_if_modified_since(&format!("{server}/tool-v2.3.4.tar.gz"), None)
            .await
            .expect("probe fallback");

        match result {
            ConditionalProbeResult::NotModified => panic!("unexpected not modified"),
            ConditionalProbeResult::Asset(asset) => {
                assert_eq!(asset.size, 11);
                // Surrounding quotes are stripped from the ETag.
                assert_eq!(asset.etag.as_deref(), Some("abc123"));
                assert!(asset.last_modified.is_some());
                assert_eq!(asset.name, "tool-v2.3.4.tar.gz");
            }
        }
    }
673
    /// Downloads a small body from the test server and checks that the file
    /// contents match and that the last progress callback reports the full length.
    #[tokio::test]
    async fn download_file_writes_bytes_and_reports_progress() {
        let body = "stream-body-data".to_string();
        let len = body.len().to_string();
        let body_for_server = body.clone();
        let server = spawn_test_server(1, move |method, _| {
            assert_eq!(method, "GET");
            http_response(
                "HTTP/1.1 200 OK",
                &[
                    ("Connection", "close"),
                    ("Content-Type", "application/octet-stream"),
                    ("Content-Length", &len),
                ],
                &body_for_server,
            )
        });
        let client = HttpClient::new(Default::default()).expect("client");
        let output = temp_file_path("download");
        // Record every (downloaded, total) pair reported by the downloader.
        let mut progress = Vec::new();
        let mut cb = Some(|downloaded: u64, total: u64| {
            progress.push((downloaded, total));
        });

        client
            .download_file(&server, &output, &mut cb)
            .await
            .expect("download file");

        assert_eq!(fs::read_to_string(&output).expect("read output file"), body);
        assert!(!progress.is_empty());
        assert_eq!(
            progress.last().copied().expect("final progress"),
            (body.len() as u64, body.len() as u64)
        );

        cleanup_file(&output).expect("cleanup output file");
    }
712}