spacetimedb_fs_utils/
compression.rs

1use std::fs::File;
2use std::io;
3use std::io::{BufReader, Read, Seek, SeekFrom};
4use tokio::io::AsyncSeek;
5use zstd_framed;
6use zstd_framed::{ZstdReader, ZstdWriter};
7
/// The four bytes that open a zstd frame.
///
/// The Zstandard frame format defines the magic number `0xFD2FB528`, stored
/// little-endian, which yields the byte sequence `28 B5 2F FD`.
pub const ZSTD_MAGIC_BYTES: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
9
/// Helper struct to keep track of the number of files compressed using each algorithm.
///
/// `Default` yields a tally with every counter at zero.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub struct CompressCount {
    /// Number of files that are stored uncompressed.
    pub none: usize,
    /// Number of files that are compressed with zstd.
    pub zstd: usize,
}
16
/// Compression type.
///
/// If `None`, the file is not compressed; otherwise it is compressed using the
/// specified algorithm.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CompressType {
    /// No compression.
    None,
    /// zstd compression.
    Zstd,
}
25
26/// A reader that can read compressed files
27pub enum CompressReader {
28    None(BufReader<File>),
29    Zstd(Box<ZstdReader<'static, BufReader<File>>>),
30}
31
32impl CompressReader {
33    /// Create a new CompressReader from a File
34    ///
35    /// It will detect the compression type using `magic bytes` and return the appropriate reader.
36    ///
37    /// **Note**: The reader will be return to the original position after detecting the compression type.
38    pub fn new(mut inner: File) -> io::Result<Self> {
39        let current_pos = inner.stream_position()?;
40
41        let mut magic_bytes = [0u8; 4];
42        let bytes_read = inner.read(&mut magic_bytes)?;
43
44        // Restore the original position
45        inner.seek(SeekFrom::Start(current_pos))?;
46
47        // Determine compression type
48        Ok(if bytes_read == 4 {
49            match magic_bytes {
50                ZSTD_MAGIC_BYTES => {
51                    let table = zstd_framed::table::read_seek_table(&mut inner)?;
52                    let mut builder = ZstdReader::builder(inner);
53                    if let Some(table) = table {
54                        builder = builder.with_seek_table(table);
55                    }
56                    CompressReader::Zstd(Box::new(builder.build()?))
57                }
58                _ => CompressReader::None(BufReader::new(inner)),
59            }
60        } else {
61            CompressReader::None(BufReader::new(inner))
62        })
63    }
64
65    pub fn file_size(&self) -> io::Result<usize> {
66        Ok(match self {
67            Self::None(inner) => inner.get_ref().metadata()?.len() as usize,
68            //TODO: Can't see how to get the file size from ZstdReader
69            Self::Zstd(_inner) => 0,
70        })
71    }
72
73    pub fn compress_type(&self) -> CompressType {
74        match self {
75            CompressReader::None(_) => CompressType::None,
76            CompressReader::Zstd(_) => CompressType::Zstd,
77        }
78    }
79}
80
81impl Read for CompressReader {
82    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
83        match self {
84            CompressReader::None(inner) => inner.read(buf),
85            CompressReader::Zstd(inner) => inner.read(buf),
86        }
87    }
88}
89
90impl io::BufRead for CompressReader {
91    fn fill_buf(&mut self) -> io::Result<&[u8]> {
92        match self {
93            CompressReader::None(inner) => inner.fill_buf(),
94            CompressReader::Zstd(inner) => inner.fill_buf(),
95        }
96    }
97
98    fn consume(&mut self, amt: usize) {
99        match self {
100            CompressReader::None(inner) => inner.consume(amt),
101            CompressReader::Zstd(inner) => inner.consume(amt),
102        }
103    }
104}
105
106impl Seek for CompressReader {
107    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
108        match self {
109            CompressReader::None(inner) => inner.seek(pos),
110            CompressReader::Zstd(inner) => inner.seek(pos),
111        }
112    }
113}
114
115pub fn new_zstd_writer<'a, W: io::Write>(inner: W, max_frame_size: Option<u32>) -> io::Result<ZstdWriter<'a, W>> {
116    let writer = ZstdWriter::builder(inner).with_compression_level(0);
117    if let Some(max_frame_size) = max_frame_size {
118        writer.with_seek_table(max_frame_size)
119    } else {
120        writer
121    }
122    .build()
123}
124
125pub fn compress_with_zstd<W: io::Write, R: io::Read>(
126    mut src: R,
127    mut dst: W,
128    max_frame_size: Option<u32>,
129) -> io::Result<()> {
130    let mut writer = new_zstd_writer(&mut dst, max_frame_size)?;
131    io::copy(&mut src, &mut writer)?;
132    writer.shutdown()?;
133    drop(writer);
134    Ok(())
135}
136
137pub use async_impls::AsyncCompressReader;
138
139pub async fn segment_len<T: AsyncSeek + Unpin>(r: &mut T) -> tokio::io::Result<u64> {
140    use tokio::io::AsyncSeekExt;
141    let old_pos = r.stream_position().await?;
142    let len = r.seek(tokio::io::SeekFrom::End(0)).await?;
143    // If we're already at the end of the file, avoid seeking.
144    if old_pos != len {
145        r.seek(tokio::io::SeekFrom::Start(old_pos)).await?;
146    }
147
148    Ok(len)
149}
150
151mod async_impls {
152    use super::*;
153    use std::pin::Pin;
154    use std::task::{Context, Poll};
155    use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
156    use zstd_framed::{AsyncZstdReader, AsyncZstdSeekableReader};
157
158    pub enum AsyncCompressReader<R> {
159        None(io::BufReader<R>),
160        Zstd(Box<AsyncZstdSeekableReader<'static, io::BufReader<R>>>),
161    }
162
163    impl<R: AsyncRead + AsyncSeek + Unpin> AsyncCompressReader<R> {
164        /// Create a new AsyncCompressReader from a reader
165        ///
166        /// It will detect the compression type using `magic bytes` and return the appropriate reader.
167        ///
168        /// **Note**: The reader will be return to the start after detecting the compression type.
169        pub async fn new(mut inner: R) -> io::Result<Self> {
170            let mut magic_bytes = [0u8; 4];
171            let magic_bytes = match inner.read_exact(&mut magic_bytes).await {
172                Ok(_) => Some(magic_bytes),
173                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None,
174                Err(e) => return Err(e),
175            };
176
177            // Restore the original position
178            inner.seek(io::SeekFrom::Start(0)).await?;
179
180            // Determine compression type
181            Ok(match magic_bytes {
182                Some(ZSTD_MAGIC_BYTES) => {
183                    let table = zstd_framed::table::tokio::read_seek_table(&mut inner).await?;
184                    let mut builder = AsyncZstdReader::builder_tokio(inner);
185                    if let Some(table) = table {
186                        builder = builder.with_seek_table(table);
187                    }
188                    AsyncCompressReader::Zstd(Box::new(builder.build()?.seekable()))
189                }
190                _ => AsyncCompressReader::None(io::BufReader::new(inner)),
191            })
192        }
193
194        pub fn compress_type(&self) -> CompressType {
195            match self {
196                AsyncCompressReader::None(_) => CompressType::None,
197                AsyncCompressReader::Zstd(_) => CompressType::Zstd,
198            }
199        }
200    }
201
202    impl AsyncCompressReader<tokio::fs::File> {
203        pub async fn file_size(&mut self) -> io::Result<u64> {
204            match self {
205                AsyncCompressReader::None(inner) => inner.get_ref().metadata().await.map(|m| m.len()),
206                AsyncCompressReader::Zstd(inner) => segment_len(inner).await,
207            }
208        }
209    }
210    macro_rules! forward_reader {
211    ($self:ident.$method:ident($($args:expr),*)) => {
212        match $self.get_mut() {
213            AsyncCompressReader::None(r) => Pin::new(r).$method($($args),*),
214            AsyncCompressReader::Zstd(r) => Pin::new(r).$method($($args),*),
215        }
216    };
217}
218    impl<R: AsyncRead + AsyncSeek + Unpin> AsyncRead for AsyncCompressReader<R> {
219        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>) -> Poll<io::Result<()>> {
220            forward_reader!(self.poll_read(cx, buf))
221        }
222    }
223    impl<R: AsyncRead + AsyncSeek + Unpin> AsyncBufRead for AsyncCompressReader<R> {
224        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
225            forward_reader!(self.poll_fill_buf(cx))
226        }
227
228        fn consume(self: Pin<&mut Self>, amt: usize) {
229            forward_reader!(self.consume(amt))
230        }
231    }
232    impl<R: AsyncRead + AsyncSeek + Unpin> AsyncSeek for AsyncCompressReader<R> {
233        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
234            forward_reader!(self.start_seek(position))
235        }
236
237        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
238            forward_reader!(self.poll_complete(cx))
239        }
240    }
241}