// taskcluster_download/factory.rs
use anyhow::Result;
use async_trait::async_trait;
use std::io::{Cursor, SeekFrom};
use tokio::fs::File;
use tokio::io::{AsyncSeekExt, AsyncWrite, AsyncWriteExt};

7#[async_trait]
11pub trait AsyncWriterFactory {
12 async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>>;
15}
16
/// An [`AsyncWriterFactory`] that writes into an in-memory [`Cursor`].
///
/// Two flavours exist:
/// - `CursorWriterFactory<Vec<u8>>` grows as needed (construct with `new`/`default`).
/// - `CursorWriterFactory<&mut [u8]>` is limited to a caller-provided slice (construct with
///   `for_buf`).
pub struct CursorWriterFactory<T>(Cursor<T>);

22#[async_trait]
23impl AsyncWriterFactory for CursorWriterFactory<Vec<u8>> {
24 async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
25 self.0.get_mut().clear();
26 self.0.set_position(0);
27 Ok(Box::new(&mut self.0))
28 }
29}
30
31#[async_trait]
32impl AsyncWriterFactory for CursorWriterFactory<&mut [u8]> {
33 async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
34 self.0.set_position(0);
35 Ok(Box::new(&mut self.0))
36 }
37}
38
39impl Default for CursorWriterFactory<Vec<u8>> {
40 fn default() -> Self {
41 Self(Cursor::new(Vec::new()))
42 }
43}
44
45impl CursorWriterFactory<Vec<u8>> {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn into_inner(self) -> Vec<u8> {
52 self.0.into_inner()
53 }
54}
55
56impl<'a> CursorWriterFactory<&'a mut [u8]> {
57 pub fn for_buf(inner: &'a mut [u8]) -> Self {
58 Self(Cursor::new(inner))
59 }
60
61 pub fn size(self) -> usize {
64 self.0.position() as usize
65 }
66}
67
68pub struct FileWriterFactory(File);
72
73#[async_trait]
74impl AsyncWriterFactory for FileWriterFactory {
75 async fn get_writer<'a>(&'a mut self) -> Result<Box<dyn AsyncWrite + Unpin + 'a>> {
76 let mut file = self.0.try_clone().await?;
77 file.set_len(0).await?;
78 file.seek(SeekFrom::Start(0)).await?;
79 Ok(Box::new(file))
80 }
81}
82
83impl FileWriterFactory {
84 pub fn new(file: File) -> Self {
85 Self(file)
86 }
87
88 pub async fn into_inner(mut self) -> Result<File> {
91 self.0.flush().await?;
92 Ok(self.0)
93 }
94}
95
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::Result;
    use tempfile::tempfile;
    use tokio::io::{copy, AsyncReadExt, AsyncSeekExt};

    /// The payload every test expects to find after its final write.
    const DATA: &[u8] = b"HELLO/WORLD";

    /// Stream `data` into a freshly obtained writer from `factory`.
    ///
    /// Panics if the factory cannot produce a writer. I/O errors from the copy itself are
    /// returned so that tests can inspect them.
    async fn copy_to_factory<F>(data: &[u8], factory: &mut F) -> std::io::Result<()>
    where
        F: AsyncWriterFactory,
    {
        let mut source = Cursor::new(data);
        let mut sink = factory.get_writer().await.unwrap();
        copy(&mut source, &mut sink).await.map(|_bytes_copied| ())
    }

    /// A second writer from a `Vec` factory replaces, rather than appends to, the first
    /// writer's output.
    #[tokio::test]
    async fn vec_cursor_writer_twice() -> Result<()> {
        let mut factory = CursorWriterFactory::new();
        copy_to_factory(b"wrong data, shouldn't see this", &mut factory).await?;
        copy_to_factory(DATA, &mut factory).await?;

        let contents = factory.into_inner();
        assert_eq!(contents.as_slice(), DATA);
        Ok(())
    }

    /// A slice factory rewinds between writers, and `size` reports only the latest write.
    #[tokio::test]
    async fn buf_cursor_writer_twice() -> Result<()> {
        let mut buf = [0u8; 256];
        let mut factory = CursorWriterFactory::for_buf(&mut buf[..]);
        copy_to_factory(b"nobody should see this", &mut factory).await?;
        copy_to_factory(DATA, &mut factory).await?;

        let written = factory.size();
        assert_eq!(&buf[..written], DATA);
        Ok(())
    }

    /// Writing more than the borrowed slice can hold surfaces `WriteZero`.
    #[tokio::test]
    async fn buf_cursor_writer_too_small() -> Result<()> {
        let mut buf = [0u8; 5];
        let mut factory = CursorWriterFactory::for_buf(&mut buf[..]);
        let outcome = copy_to_factory(DATA, &mut factory).await;
        assert_eq!(outcome.unwrap_err().kind(), std::io::ErrorKind::WriteZero);
        Ok(())
    }

    /// A file factory truncates the longer first write before the second one lands.
    #[tokio::test]
    async fn file_writer_twice() -> Result<()> {
        let std_file = tempfile()?;
        let mut factory = FileWriterFactory::new(std_file.into());
        copy_to_factory(b"wrong data, shouldn't see this", &mut factory).await?;
        copy_to_factory(DATA, &mut factory).await?;

        let mut file = factory.into_inner().await?;
        file.seek(SeekFrom::Start(0)).await?;

        let mut contents = Vec::new();
        file.read_to_end(&mut contents).await?;
        assert_eq!(contents.as_slice(), DATA);
        Ok(())
    }
}