tugger_common/
http.rs

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use {
6    anyhow::{anyhow, Context, Result},
7    fs2::FileExt,
8    log::warn,
9    sha2::Digest,
10    std::{fs::File, io::Read, path::Path},
11    url::Url,
12};
13
/// Defines remote content that can be downloaded securely.
///
/// The content is located by a URL and pinned by its SHA-256 digest; a
/// download is only accepted when its bytes hash to `sha256`.
#[derive(Clone, Debug)]
pub struct RemoteContent {
    /// Identifier used for configuring an override for the URL.
    ///
    /// When the environment variable `<name>_URL` is set, its value is used
    /// instead of `url`.
    pub name: String,
    /// Default URL to fetch the content from.
    pub url: String,
    /// Expected SHA-256 digest of the content, encoded as a hex string.
    pub sha256: String,
}
21
22fn sha256_path<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
23    let mut hasher = sha2::Sha256::new();
24    let fh = std::fs::File::open(&path)?;
25    let mut reader = std::io::BufReader::new(fh);
26
27    let mut buffer = [0; 32768];
28
29    loop {
30        let count = reader.read(&mut buffer)?;
31        if count == 0 {
32            break;
33        }
34        hasher.update(&buffer[..count]);
35    }
36
37    Ok(hasher.finalize().to_vec())
38}
39
40/// Obtain an HTTP client, taking proxy environment variables into account.
41pub fn get_http_client() -> reqwest::Result<reqwest::blocking::Client> {
42    let mut builder = reqwest::blocking::ClientBuilder::new();
43
44    for (key, value) in std::env::vars() {
45        let key = key.to_lowercase();
46        if key.ends_with("_proxy") {
47            let end = key.len() - "_proxy".len();
48            let schema = &key[..end];
49
50            if let Ok(url) = Url::parse(&value) {
51                if let Some(Ok(proxy)) = match schema {
52                    "http" => Some(reqwest::Proxy::http(url.as_str())),
53                    "https" => Some(reqwest::Proxy::https(url.as_str())),
54                    _ => None,
55                } {
56                    builder = builder.proxy(proxy);
57                }
58            }
59        }
60    }
61
62    builder.build()
63}
64
65/// Fetch a URL and verify its SHA-256 matches expectations.
66pub fn download_and_verify(entry: &RemoteContent) -> Result<Vec<u8>> {
67    let url =
68        std::env::var(format!("{}_URL", &entry.name)).unwrap_or_else(|_err| entry.url.to_string());
69    warn!("downloading {}", url);
70    let url = Url::parse(&url)?;
71    let client = get_http_client()?;
72    let mut data: Vec<u8> = Vec::new();
73    if url.scheme() == "file" {
74        let file_path = url
75            .to_file_path()
76            .map_err(|_err: ()| anyhow!("bad url for {}: {}", entry.name, url))?;
77        let mut file = File::open(file_path)?;
78        file.read_to_end(&mut data)?;
79    } else {
80        let mut response = client.get(url).send()?;
81        response.read_to_end(&mut data)?;
82    }
83
84    let mut hasher = sha2::Sha256::new();
85    hasher.update(&data);
86
87    let url_hash = hasher.finalize().to_vec();
88    let expected_hash = hex::decode(&entry.sha256)?;
89
90    if expected_hash == url_hash {
91        warn!("verified SHA-256 is {}", entry.sha256);
92        Ok(data)
93    } else {
94        Err(anyhow!("hash mismatch of downloaded file"))
95    }
96}
97
98/// Ensure a URL with specified hash exists in a local filesystem path.
99pub fn download_to_path<P: AsRef<Path>>(entry: &RemoteContent, dest_path: P) -> Result<()> {
100    let dest_path = dest_path.as_ref();
101
102    if let Some(dest_dir) = dest_path.parent() {
103        std::fs::create_dir_all(dest_dir)
104            .with_context(|| format!("creating directory {}", dest_dir.display()))?;
105    }
106
107    let expected_hash = hex::decode(&entry.sha256)?;
108
109    let lock_path = dest_path.with_extension("lock");
110    let lock = std::fs::File::create(&lock_path)
111        .with_context(|| format!("creating {}", lock_path.display()))?;
112    lock.lock_exclusive().context("obtaining lock")?;
113
114    if dest_path.exists() {
115        let file_hash = sha256_path(dest_path)?;
116
117        if file_hash == expected_hash {
118            lock.unlock().context("unlocking")?;
119            return Ok(());
120        }
121
122        // Hash mismatch. Remove the current file.
123        std::fs::remove_file(dest_path)?;
124    }
125
126    let data = download_and_verify(entry).context("downloading with verification")?;
127    let temp_path = dest_path.with_file_name(format!(
128        "{}.tmp",
129        dest_path
130            .file_name()
131            .ok_or_else(|| anyhow!("unable to obtain file name"))?
132            .to_string_lossy()
133    ));
134
135    std::fs::write(&temp_path, data).context("writing data to temporary file")?;
136    std::fs::rename(&temp_path, dest_path).with_context(|| {
137        format!(
138            "renaming {} to {}",
139            temp_path.display(),
140            dest_path.display()
141        )
142    })?;
143    lock.unlock().context("unlocking")?;
144
145    Ok(())
146}