Skip to main content

shell_download/
sink.rs

1use std::fs::File;
2use std::io::{self, Read, Write};
3use std::process::Command;
4use std::sync::{Arc, Mutex};
5use std::thread;
6
7use crate::ResponseError;
8
9/// Where a backend writes the downloaded response body.
10///
11/// Child processes (or a [`std::io::PipeReader`]) stream bytes into this sink from a worker thread
12/// via [`DownloadSink::spawn_stdout_drain`]. If the stream begins with the gzip magic bytes,
13/// bytes are piped through `gzip -dc` while copying (streaming decompress).
14#[derive(Clone, Debug)]
15pub struct DownloadSink {
16    inner: SinkInner,
17}
18
/// Backing storage of a [`DownloadSink`]; clones share the same `Arc`-held storage.
#[derive(Debug, Clone)]
enum SinkInner {
    /// Target file slot; the drain thread takes the file out, leaving `None` behind.
    File(Arc<Mutex<Option<File>>>),
    /// In-memory buffer that the (possibly decompressed) body is appended to.
    Buffer(Arc<Mutex<Vec<u8>>>),
}
24
/// Reports whether the sniffed prefix is the gzip magic number (`1f 8b`).
///
/// `n` is how many bytes of `peek` were actually filled; anything short of two bytes can
/// never be gzip.
fn sniff_is_gzip(n: usize, peek: &[u8; 2]) -> bool {
    const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
    n == GZIP_MAGIC.len() && *peek == GZIP_MAGIC
}
28
29fn copy_gzip_stream<R: Read + Send + 'static>(
30    mut stream: R,
31    peek: [u8; 2],
32    mut out: impl Write,
33) -> Result<u64, ResponseError> {
34    let mut cmd = Command::new("gzip");
35    cmd.arg("-dc");
36    let (mut child, mut stdin, mut stdout, mut stderr) =
37        crate::process::spawn_stdin_stdout_stderr(&mut cmd).map_err(ResponseError::Io)?;
38
39    let feed = thread::spawn(move || -> io::Result<()> {
40        stdin.write_all(&peek)?;
41        io::copy(&mut stream, &mut stdin)?;
42        Ok(())
43    });
44
45    let stderr_join = thread::spawn(move || {
46        let mut buf = Vec::new();
47        let _ = stderr.read_to_end(&mut buf);
48        buf
49    });
50
51    let copied = io::copy(&mut stdout, &mut out).map_err(ResponseError::Io)?;
52
53    match feed.join() {
54        Ok(Ok(())) => {}
55        Ok(Err(e)) => {
56            let _ = child.wait();
57            return Err(ResponseError::Io(e));
58        }
59        Err(_) => {
60            let _ = child.wait();
61            return Err(ResponseError::ThreadPanicked);
62        }
63    }
64
65    let status = child.wait().map_err(ResponseError::Io)?;
66    let stderr_bytes = stderr_join.join().unwrap_or_default();
67    if !status.success() {
68        return Err(ResponseError::GzipFailed {
69            exit_code: status.code(),
70            stderr: String::from_utf8_lossy(&stderr_bytes).to_string(),
71        });
72    }
73    Ok(copied)
74}
75
76fn copy_stream_maybe_gunzip<R: Read + Send + 'static>(
77    mut stream: R,
78    mut out: impl Write,
79) -> Result<u64, ResponseError> {
80    let mut peek = [0u8; 2];
81    let n = stream.read(&mut peek).map_err(ResponseError::Io)?;
82    if sniff_is_gzip(n, &peek) {
83        return copy_gzip_stream(stream, peek, out);
84    }
85    out.write_all(&peek[..n]).map_err(ResponseError::Io)?;
86    io::copy(&mut stream, &mut out).map_err(ResponseError::Io)
87}
88
89impl DownloadSink {
90    /// Write the body directly to a file.
91    pub fn file(target_file: File) -> Self {
92        Self {
93            inner: SinkInner::File(Arc::new(Mutex::new(Some(target_file)))),
94        }
95    }
96
97    /// Accumulate the decompressed body in memory (gzipped payloads are expanded while streaming).
98    pub fn buffer(buffer: Arc<Mutex<Vec<u8>>>) -> Self {
99        Self {
100            inner: SinkInner::Buffer(buffer),
101        }
102    }
103
104    /// Spawn a thread that reads `stream` into this sink (file or buffer).
105    pub(crate) fn spawn_stdout_drain(
106        self,
107        stream: impl Read + Send + 'static,
108    ) -> thread::JoinHandle<Result<u64, ResponseError>> {
109        thread::spawn(move || match self.inner {
110            SinkInner::File(file) => {
111                let mut guard = file.lock().unwrap();
112                let mut f = guard.take().expect("file already finalized");
113                copy_stream_maybe_gunzip(stream, &mut f)
114            }
115            SinkInner::Buffer(buf) => {
116                let mut g = buf.lock().unwrap();
117                copy_stream_maybe_gunzip(stream, &mut *g)
118            }
119        })
120    }
121}