1use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
13use std::path::Path;
14
15use tempfile::NamedTempFile;
16
/// Size in bytes (64 KiB) of the scratch buffers used for chunked reads, and
/// the capacity given to the `BufWriter` that wraps the spill file.
const READ_CHUNK_SIZE: usize = 64 * 1024;
20
21pub enum Buffer {
23 InMemory(Vec<u8>),
25 Spilled {
27 writer: BufWriter<NamedTempFile>,
28 len: u64,
30 },
31}
32
33impl Buffer {
34 pub fn new() -> Self {
36 Buffer::InMemory(Vec::new())
37 }
38
39 pub fn len(&self) -> u64 {
41 match self {
42 Buffer::InMemory(v) => v.len() as u64,
43 Buffer::Spilled { len, .. } => *len,
44 }
45 }
46
47 pub fn is_empty(&self) -> bool {
49 self.len() == 0
50 }
51
52 pub fn drain_reader<R: Read>(
64 &mut self,
65 mut reader: R,
66 threshold: usize,
67 spill_dir: &Path,
68 ) -> io::Result<()> {
69 let mut chunk = vec![0u8; READ_CHUNK_SIZE];
70 loop {
71 #[cfg(feature = "cli")]
72 if crate::signal::is_cancelled() {
73 return Err(io::Error::new(
74 io::ErrorKind::Interrupted,
75 "rusty-sponge: cancelled by signal",
76 ));
77 }
78 let n = reader.read(&mut chunk)?;
79 if n == 0 {
80 break;
81 }
82 self.append(&chunk[..n], threshold, spill_dir)?;
83 }
84 Ok(())
85 }
86
87 pub fn append(&mut self, bytes: &[u8], threshold: usize, spill_dir: &Path) -> io::Result<()> {
90 let threshold_u64 = threshold as u64;
91 let projected_len = self.len() + bytes.len() as u64;
92
93 if matches!(self, Buffer::InMemory(_)) && projected_len > threshold_u64 {
96 self.transition_to_spilled(spill_dir)?;
97 }
98
99 match self {
100 Buffer::InMemory(v) => v.extend_from_slice(bytes),
101 Buffer::Spilled { writer, len } => {
102 writer.write_all(bytes)?;
103 *len += bytes.len() as u64;
104 }
105 }
106 Ok(())
107 }
108
109 pub fn transition_to_spilled(&mut self, spill_dir: &Path) -> io::Result<()> {
112 if let Buffer::InMemory(bytes) = std::mem::replace(self, Buffer::InMemory(Vec::new())) {
113 let tempfile = tempfile::Builder::new()
114 .prefix(".rusty-sponge-spill-")
115 .tempfile_in(spill_dir)?;
116 let mut writer = BufWriter::with_capacity(READ_CHUNK_SIZE, tempfile);
117 writer.write_all(&bytes)?;
118 let len = bytes.len() as u64;
119 *self = Buffer::Spilled { writer, len };
120 }
121 Ok(())
122 }
123
124 pub fn write_to<W: Write>(self, mut out: W) -> io::Result<()> {
129 match self {
130 Buffer::InMemory(v) => out.write_all(&v),
131 Buffer::Spilled { writer, .. } => {
132 let mut tempfile = writer
133 .into_inner()
134 .map_err(|e| io::Error::other(format!("BufWriter flush failed: {e}")))?;
135 tempfile.as_file_mut().seek(SeekFrom::Start(0))?;
136 let mut chunk = vec![0u8; READ_CHUNK_SIZE];
137 let mut reader = tempfile.as_file();
138 loop {
139 let n = reader.read(&mut chunk)?;
140 if n == 0 {
141 break;
142 }
143 out.write_all(&chunk[..n])?;
144 }
145 Ok(())
146 }
147 }
148 }
149}
150
151impl Default for Buffer {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Spill threshold large enough that the small fixtures stay in memory.
    const ONE_MIB: usize = 1024 * 1024;

    /// Drains `input` into a fresh buffer that spills into a new temporary
    /// directory. The directory guard is returned so it outlives the buffer.
    fn drained(input: &[u8], threshold: usize) -> (tempfile::TempDir, Buffer) {
        let dir = tempfile::tempdir().unwrap();
        let mut buffer = Buffer::new();
        buffer
            .drain_reader(Cursor::new(input), threshold, dir.path())
            .unwrap();
        (dir, buffer)
    }

    /// Consumes `buffer` and returns everything it writes out.
    fn written(buffer: Buffer) -> Vec<u8> {
        let mut sink = Vec::new();
        buffer.write_to(&mut sink).unwrap();
        sink
    }

    #[test]
    fn empty_buffer_has_len_zero() {
        assert_eq!(Buffer::new().len(), 0);
    }

    #[test]
    fn drain_small_input_stays_in_memory() {
        let (_dir, buffer) = drained(b"hello world\n", ONE_MIB);
        assert!(matches!(buffer, Buffer::InMemory(_)));
        assert_eq!(buffer.len(), 12);
    }

    #[test]
    fn drain_large_input_transitions_to_spilled() {
        let payload = vec![0xAAu8; 256 * 1024];
        let (_dir, buffer) = drained(&payload, 64 * 1024);
        assert!(matches!(buffer, Buffer::Spilled { .. }));
        assert_eq!(buffer.len(), 256 * 1024);
    }

    #[test]
    fn write_to_roundtrips_in_memory() {
        let (_dir, buffer) = drained(b"abc\n", ONE_MIB);
        assert_eq!(written(buffer), b"abc\n");
    }

    #[test]
    fn write_to_roundtrips_spilled() {
        let payload: Vec<u8> = (0u8..=255u8).cycle().take(256 * 1024).collect();
        let (_dir, buffer) = drained(&payload, 1024);
        assert!(matches!(buffer, Buffer::Spilled { .. }));
        assert_eq!(written(buffer), payload);
    }

    #[test]
    fn binary_bytes_pass_through_unchanged() {
        let raw: &[u8] = &[0x00, 0xFE, 0xFF, 0xC3, 0x28, 0xA0, 0xA1];
        let (_dir, buffer) = drained(raw, ONE_MIB);
        assert_eq!(written(buffer), raw);
    }

    #[test]
    fn empty_input_writes_zero_bytes() {
        let (_dir, buffer) = drained(&[], ONE_MIB);
        assert_eq!(written(buffer), Vec::<u8>::new());
    }
}