zed_http_client/
github_download.rs

1use std::{path::Path, pin::Pin, task::Poll};
2
3use anyhow::{Context, Result};
4use async_compression::futures::bufread::GzipDecoder;
5use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, io::BufReader};
6use sha2::{Digest, Sha256};
7
8use crate::{HttpClient, github::AssetKind};
9
10#[derive(serde::Deserialize, serde::Serialize, Debug)]
11pub struct GithubBinaryMetadata {
12    pub metadata_version: u64,
13    pub digest: Option<String>,
14}
15
16impl GithubBinaryMetadata {
17    pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
18        let metadata_content = async_fs::read_to_string(metadata_path)
19            .await
20            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
21        serde_json::from_str(&metadata_content)
22            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
23    }
24
25    pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
26        let metadata_content = serde_json::to_string(self)
27            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
28        async_fs::write(metadata_path, metadata_content.as_bytes())
29            .await
30            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
31        Ok(())
32    }
33}
34
35pub async fn download_server_binary(
36    http_client: &dyn HttpClient,
37    url: &str,
38    digest: Option<&str>,
39    destination_path: &Path,
40    asset_kind: AssetKind,
41) -> Result<(), anyhow::Error> {
42    log::info!("downloading github artifact from {url}");
43    let mut response = http_client
44        .get(url, Default::default(), true)
45        .await
46        .with_context(|| format!("downloading release from {url}"))?;
47    let body = response.body_mut();
48    match digest {
49        Some(expected_sha_256) => {
50            let temp_asset_file = tempfile::NamedTempFile::new()
51                .with_context(|| format!("creating a temporary file for {url}"))?;
52            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
53            let mut writer = HashingWriter {
54                writer: async_fs::File::from(temp_asset_file),
55                hasher: Sha256::new(),
56            };
57            futures::io::copy(&mut BufReader::new(body), &mut writer)
58                .await
59                .with_context(|| {
60                    format!("saving archive contents into the temporary file for {url}",)
61                })?;
62            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
63
64            anyhow::ensure!(
65                asset_sha_256 == expected_sha_256,
66                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
67            );
68            writer
69                .writer
70                .seek(std::io::SeekFrom::Start(0))
71                .await
72                .with_context(|| format!("seeking temporary file {destination_path:?}",))?;
73            stream_file_archive(&mut writer.writer, url, destination_path, asset_kind)
74                .await
75                .with_context(|| {
76                    format!("extracting downloaded asset for {url} into {destination_path:?}",)
77                })?;
78        }
79        None => stream_response_archive(body, url, destination_path, asset_kind)
80            .await
81            .with_context(|| {
82                format!("extracting response for asset {url} into {destination_path:?}",)
83            })?,
84    }
85    Ok(())
86}
87
88async fn stream_response_archive(
89    response: impl AsyncRead + Unpin,
90    url: &str,
91    destination_path: &Path,
92    asset_kind: AssetKind,
93) -> Result<()> {
94    match asset_kind {
95        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
96        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
97        AssetKind::Zip => {
98            util::archive::extract_zip(destination_path, response).await?;
99        }
100    };
101    Ok(())
102}
103
104async fn stream_file_archive(
105    file_archive: impl AsyncRead + AsyncSeek + Unpin,
106    url: &str,
107    destination_path: &Path,
108    asset_kind: AssetKind,
109) -> Result<()> {
110    match asset_kind {
111        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
112        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
113        #[cfg(not(windows))]
114        AssetKind::Zip => {
115            util::archive::extract_seekable_zip(destination_path, file_archive).await?;
116        }
117        #[cfg(windows)]
118        AssetKind::Zip => {
119            util::archive::extract_zip(destination_path, file_archive).await?;
120        }
121    };
122    Ok(())
123}
124
125async fn extract_tar_gz(
126    destination_path: &Path,
127    url: &str,
128    from: impl AsyncRead + Unpin,
129) -> Result<(), anyhow::Error> {
130    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
131    let archive = async_tar::Archive::new(decompressed_bytes);
132    archive
133        .unpack(&destination_path)
134        .await
135        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
136    Ok(())
137}
138
139async fn extract_gz(
140    destination_path: &Path,
141    url: &str,
142    from: impl AsyncRead + Unpin,
143) -> Result<(), anyhow::Error> {
144    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
145    let mut file = async_fs::File::create(&destination_path)
146        .await
147        .with_context(|| {
148            format!("creating a file {destination_path:?} for a download from {url}")
149        })?;
150    futures::io::copy(&mut decompressed_bytes, &mut file)
151        .await
152        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
153    Ok(())
154}
155
156struct HashingWriter<W: AsyncWrite + Unpin> {
157    writer: W,
158    hasher: Sha256,
159}
160
161impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
162    fn poll_write(
163        mut self: Pin<&mut Self>,
164        cx: &mut std::task::Context<'_>,
165        buf: &[u8],
166    ) -> Poll<std::result::Result<usize, std::io::Error>> {
167        match Pin::new(&mut self.writer).poll_write(cx, buf) {
168            Poll::Ready(Ok(n)) => {
169                self.hasher.update(&buf[..n]);
170                Poll::Ready(Ok(n))
171            }
172            other => other,
173        }
174    }
175
176    fn poll_flush(
177        mut self: Pin<&mut Self>,
178        cx: &mut std::task::Context<'_>,
179    ) -> Poll<Result<(), std::io::Error>> {
180        Pin::new(&mut self.writer).poll_flush(cx)
181    }
182
183    fn poll_close(
184        mut self: Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> Poll<std::result::Result<(), std::io::Error>> {
187        Pin::new(&mut self.writer).poll_close(cx)
188    }
189}