taskcluster_download/
factory.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use std::io::{Cursor, SeekFrom};
4use tokio::fs::File;
5use tokio::io::{AsyncSeekExt, AsyncWrite, AsyncWriteExt};
6
7/// An AsyncWriterFactory can produce, on demand, an [AsyncWrite] object.  In the event of a
8/// download failure, the restarted download will use a fresh writer to restart writing at the
9/// beginning.
10#[async_trait]
11pub trait AsyncWriterFactory {
12    /// Get a fresh [AsyncWrite] object, positioned at the point where downloaded data should
13    /// be written.
14    async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>>;
15}
16
/// A `CursorWriterFactory` creates [AsyncWrite] objects from a [std::io::Cursor], allowing
/// downloads to in-memory buffers.  It is specialized for `Vec<u8>` (which grows indefinitely)
/// and `&mut [u8]` (which has a fixed maximum size).
#[derive(Debug)]
pub struct CursorWriterFactory<T>(Cursor<T>);
21
22#[async_trait]
23impl AsyncWriterFactory for CursorWriterFactory<Vec<u8>> {
24    async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
25        self.0.get_mut().clear();
26        self.0.set_position(0);
27        Ok(Box::new(&mut self.0))
28    }
29}
30
31#[async_trait]
32impl AsyncWriterFactory for CursorWriterFactory<&mut [u8]> {
33    async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
34        self.0.set_position(0);
35        Ok(Box::new(&mut self.0))
36    }
37}
38
39impl Default for CursorWriterFactory<Vec<u8>> {
40    fn default() -> Self {
41        Self(Cursor::new(Vec::new()))
42    }
43}
44
45impl CursorWriterFactory<Vec<u8>> {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Consume the factory, returning the vector into which the data was read
51    pub fn into_inner(self) -> Vec<u8> {
52        self.0.into_inner()
53    }
54}
55
56impl<'a> CursorWriterFactory<&'a mut [u8]> {
57    pub fn for_buf(inner: &'a mut [u8]) -> Self {
58        Self(Cursor::new(inner))
59    }
60
61    /// Return the size of the data written to the buffer.  This value should
62    /// be used to slice the resulting data from the buffer.
63    pub fn size(self) -> usize {
64        self.0.position() as usize
65    }
66}
67
68/// A FileWriterFactory creates [AsyncWrite] objects by rewinding and cloning a [tokio::fs::File].
69/// The file must be open in write mode and must be clone-able (that is, [File::try_clone()] must
70/// succeed) in order to support retried uploads.
71pub struct FileWriterFactory(File);
72
73#[async_trait]
74impl AsyncWriterFactory for FileWriterFactory {
75    async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
76        let mut file = self.0.try_clone().await?;
77        file.set_len(0).await?;
78        file.seek(SeekFrom::Start(0)).await?;
79        Ok(Box::new(file))
80    }
81}
82
83impl FileWriterFactory {
84    pub fn new(file: File) -> Self {
85        Self(file)
86    }
87
88    /// Return the File, after finishing any concurrent async operations.  The
89    /// file posiion is unspecified.
90    pub async fn into_inner(mut self) -> Result<File> {
91        self.0.flush().await?;
92        Ok(self.0)
93    }
94}
95
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::Result;
    use tempfile::tempfile;
    use tokio::io::{copy, AsyncReadExt, AsyncSeekExt};

    /// Payload every test expects to find after the final attempt.
    const DATA: &[u8] = b"HELLO/WORLD";

    /// Simulate one download attempt: get a fresh writer from `factory` and stream `data`
    /// into it.
    async fn copy_to_factory<F: AsyncWriterFactory>(
        data: &[u8],
        factory: &mut F,
    ) -> std::io::Result<()> {
        let mut sink = factory.get_writer().await.unwrap();
        let mut source = Cursor::new(data);
        copy(&mut source, &mut sink).await.map(|_bytes| ())
    }

    #[tokio::test]
    async fn vec_cursor_writer_twice() -> Result<()> {
        let mut factory = CursorWriterFactory::new();
        // The second get_writer call must discard the first attempt's bytes.
        copy_to_factory(b"wrong data, shouldn't see this", &mut factory).await?;
        copy_to_factory(DATA, &mut factory).await?;
        let collected = factory.into_inner();
        assert_eq!(&collected, DATA);
        Ok(())
    }

    #[tokio::test]
    async fn buf_cursor_writer_twice() -> Result<()> {
        let mut buf = [0u8; 256];
        let written = {
            let mut factory = CursorWriterFactory::for_buf(&mut buf[..]);
            copy_to_factory(b"nobody should see this", &mut factory).await?;
            copy_to_factory(DATA, &mut factory).await?;
            factory.size()
        };
        assert_eq!(&buf[..written], DATA);
        Ok(())
    }

    #[tokio::test]
    async fn buf_cursor_writer_too_small() -> Result<()> {
        // DATA is longer than this buffer, so the copy has to fail.
        let mut buf = [0u8; 5];
        let mut factory = CursorWriterFactory::for_buf(&mut buf[..]);
        let kind = copy_to_factory(DATA, &mut factory).await.unwrap_err().kind();
        assert_eq!(kind, std::io::ErrorKind::WriteZero);
        Ok(())
    }

    #[tokio::test]
    async fn file_writer_twice() -> Result<()> {
        let mut factory = FileWriterFactory::new(File::from_std(tempfile()?));
        copy_to_factory(b"wrong data, shouldn't see this", &mut factory).await?;
        copy_to_factory(DATA, &mut factory).await?;

        let mut file = factory.into_inner().await?;
        file.seek(SeekFrom::Start(0)).await?;
        let mut contents = Vec::new();
        file.read_to_end(&mut contents).await?;
        assert_eq!(&contents, DATA);
        Ok(())
    }
}