spacetimedb_fs_utils/
compression.rs1use std::fs::File;
2use std::io;
3use std::io::{BufReader, Read, Seek, SeekFrom};
4use tokio::io::AsyncSeek;
5use zstd_framed;
6use zstd_framed::{ZstdReader, ZstdWriter};
7
/// Byte signature found at the start of a zstd frame.
///
/// The zstd format defines its magic number as the 32-bit value `0xFD2FB528`
/// stored little-endian, which yields the bytes `28 B5 2F FD`.
pub const ZSTD_MAGIC_BYTES: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
9
/// A pair of counters whose fields mirror the variants of [`CompressType`].
///
/// NOTE(review): presumably a tally of items (e.g. files) seen per compression
/// type; what exactly is counted is decided by the caller — confirm at the
/// call sites.
///
/// `Default` yields both counters set to zero. `Eq` is derived alongside
/// `PartialEq` because every field is a plain `usize`, so equality is total.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub struct CompressCount {
    /// Counter for the "not compressed" case.
    pub none: usize,
    /// Counter for the zstd case.
    pub zstd: usize,
}
16
/// The compression scheme applied to a file.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no payload, so
/// equality is total and the type can be used wherever `Eq` is required.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CompressType {
    /// The data is stored without compression.
    None,
    /// The data is stored as a zstd stream.
    Zstd,
}
25
26pub enum CompressReader {
28 None(BufReader<File>),
29 Zstd(Box<ZstdReader<'static, BufReader<File>>>),
30}
31
32impl CompressReader {
33 pub fn new(mut inner: File) -> io::Result<Self> {
39 let current_pos = inner.stream_position()?;
40
41 let mut magic_bytes = [0u8; 4];
42 let bytes_read = inner.read(&mut magic_bytes)?;
43
44 inner.seek(SeekFrom::Start(current_pos))?;
46
47 Ok(if bytes_read == 4 {
49 match magic_bytes {
50 ZSTD_MAGIC_BYTES => {
51 let table = zstd_framed::table::read_seek_table(&mut inner)?;
52 let mut builder = ZstdReader::builder(inner);
53 if let Some(table) = table {
54 builder = builder.with_seek_table(table);
55 }
56 CompressReader::Zstd(Box::new(builder.build()?))
57 }
58 _ => CompressReader::None(BufReader::new(inner)),
59 }
60 } else {
61 CompressReader::None(BufReader::new(inner))
62 })
63 }
64
65 pub fn file_size(&self) -> io::Result<usize> {
66 Ok(match self {
67 Self::None(inner) => inner.get_ref().metadata()?.len() as usize,
68 Self::Zstd(_inner) => 0,
70 })
71 }
72
73 pub fn compress_type(&self) -> CompressType {
74 match self {
75 CompressReader::None(_) => CompressType::None,
76 CompressReader::Zstd(_) => CompressType::Zstd,
77 }
78 }
79}
80
81impl Read for CompressReader {
82 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
83 match self {
84 CompressReader::None(inner) => inner.read(buf),
85 CompressReader::Zstd(inner) => inner.read(buf),
86 }
87 }
88}
89
90impl io::BufRead for CompressReader {
91 fn fill_buf(&mut self) -> io::Result<&[u8]> {
92 match self {
93 CompressReader::None(inner) => inner.fill_buf(),
94 CompressReader::Zstd(inner) => inner.fill_buf(),
95 }
96 }
97
98 fn consume(&mut self, amt: usize) {
99 match self {
100 CompressReader::None(inner) => inner.consume(amt),
101 CompressReader::Zstd(inner) => inner.consume(amt),
102 }
103 }
104}
105
106impl Seek for CompressReader {
107 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
108 match self {
109 CompressReader::None(inner) => inner.seek(pos),
110 CompressReader::Zstd(inner) => inner.seek(pos),
111 }
112 }
113}
114
115pub fn new_zstd_writer<'a, W: io::Write>(inner: W, max_frame_size: Option<u32>) -> io::Result<ZstdWriter<'a, W>> {
116 let writer = ZstdWriter::builder(inner).with_compression_level(0);
117 if let Some(max_frame_size) = max_frame_size {
118 writer.with_seek_table(max_frame_size)
119 } else {
120 writer
121 }
122 .build()
123}
124
125pub fn compress_with_zstd<W: io::Write, R: io::Read>(
126 mut src: R,
127 mut dst: W,
128 max_frame_size: Option<u32>,
129) -> io::Result<()> {
130 let mut writer = new_zstd_writer(&mut dst, max_frame_size)?;
131 io::copy(&mut src, &mut writer)?;
132 writer.shutdown()?;
133 drop(writer);
134 Ok(())
135}
136
137pub use async_impls::AsyncCompressReader;
138
139pub async fn segment_len<T: AsyncSeek + Unpin>(r: &mut T) -> tokio::io::Result<u64> {
140 use tokio::io::AsyncSeekExt;
141 let old_pos = r.stream_position().await?;
142 let len = r.seek(tokio::io::SeekFrom::End(0)).await?;
143 if old_pos != len {
145 r.seek(tokio::io::SeekFrom::Start(old_pos)).await?;
146 }
147
148 Ok(len)
149}
150
151mod async_impls {
152 use super::*;
153 use std::pin::Pin;
154 use std::task::{Context, Poll};
155 use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
156 use zstd_framed::{AsyncZstdReader, AsyncZstdSeekableReader};
157
158 pub enum AsyncCompressReader<R> {
159 None(io::BufReader<R>),
160 Zstd(Box<AsyncZstdSeekableReader<'static, io::BufReader<R>>>),
161 }
162
163 impl<R: AsyncRead + AsyncSeek + Unpin> AsyncCompressReader<R> {
164 pub async fn new(mut inner: R) -> io::Result<Self> {
170 let mut magic_bytes = [0u8; 4];
171 let magic_bytes = match inner.read_exact(&mut magic_bytes).await {
172 Ok(_) => Some(magic_bytes),
173 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None,
174 Err(e) => return Err(e),
175 };
176
177 inner.seek(io::SeekFrom::Start(0)).await?;
179
180 Ok(match magic_bytes {
182 Some(ZSTD_MAGIC_BYTES) => {
183 let table = zstd_framed::table::tokio::read_seek_table(&mut inner).await?;
184 let mut builder = AsyncZstdReader::builder_tokio(inner);
185 if let Some(table) = table {
186 builder = builder.with_seek_table(table);
187 }
188 AsyncCompressReader::Zstd(Box::new(builder.build()?.seekable()))
189 }
190 _ => AsyncCompressReader::None(io::BufReader::new(inner)),
191 })
192 }
193
194 pub fn compress_type(&self) -> CompressType {
195 match self {
196 AsyncCompressReader::None(_) => CompressType::None,
197 AsyncCompressReader::Zstd(_) => CompressType::Zstd,
198 }
199 }
200 }
201
202 impl AsyncCompressReader<tokio::fs::File> {
203 pub async fn file_size(&mut self) -> io::Result<u64> {
204 match self {
205 AsyncCompressReader::None(inner) => inner.get_ref().metadata().await.map(|m| m.len()),
206 AsyncCompressReader::Zstd(inner) => segment_len(inner).await,
207 }
208 }
209 }
210 macro_rules! forward_reader {
211 ($self:ident.$method:ident($($args:expr),*)) => {
212 match $self.get_mut() {
213 AsyncCompressReader::None(r) => Pin::new(r).$method($($args),*),
214 AsyncCompressReader::Zstd(r) => Pin::new(r).$method($($args),*),
215 }
216 };
217}
218 impl<R: AsyncRead + AsyncSeek + Unpin> AsyncRead for AsyncCompressReader<R> {
219 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>) -> Poll<io::Result<()>> {
220 forward_reader!(self.poll_read(cx, buf))
221 }
222 }
223 impl<R: AsyncRead + AsyncSeek + Unpin> AsyncBufRead for AsyncCompressReader<R> {
224 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
225 forward_reader!(self.poll_fill_buf(cx))
226 }
227
228 fn consume(self: Pin<&mut Self>, amt: usize) {
229 forward_reader!(self.consume(amt))
230 }
231 }
232 impl<R: AsyncRead + AsyncSeek + Unpin> AsyncSeek for AsyncCompressReader<R> {
233 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
234 forward_reader!(self.start_seek(position))
235 }
236
237 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
238 forward_reader!(self.poll_complete(cx))
239 }
240 }
241}