sequoia_openpgp/serialize/stream/
partial_body.rs1use std::fmt;
4use std::io;
5use std::cmp;
6
7use crate::Error;
8use crate::Result;
9use crate::packet::header::BodyLength;
10use crate::serialize::{
11 log2,
12 stream::{
13 writer,
14 Message,
15 Cookie,
16 },
17 write_byte,
18 Marshal,
19};
20
21pub struct PartialBodyFilter<'a, C: 'a> {
22 inner: Option<writer::BoxStack<'a, C>>,
30
31 cookie: C,
33
34 buffer: Vec<u8>,
36
37 buffer_threshold: usize,
39
40 max_chunk_size: usize,
43
44 position: u64,
46}
47assert_send_and_sync!(PartialBodyFilter<'_, C> where C);
48
/// The largest partial body chunk this filter will emit: 2^30 bytes
/// (1 GiB), the largest size a partial body length can express.
const PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE: usize = 1 << 30;

/// Default amount of data (2^22 bytes, i.e. 4 MiB) to buffer before
/// emitting chunks.  A small value leads to many small partial body
/// chunks.
const PARTIAL_BODY_FILTER_BUFFER_THRESHOLD: usize = 1 << 22;
54
55impl<'a> PartialBodyFilter<'a, Cookie> {
56 pub fn new(inner: Message<'a>, cookie: Cookie)
58 -> Message<'a> {
59 Self::with_limits(inner, cookie,
60 PARTIAL_BODY_FILTER_BUFFER_THRESHOLD,
61 PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE)
62 .expect("safe limits")
63 }
64
65 pub fn with_limits(inner: Message<'a>, cookie: Cookie,
67 buffer_threshold: usize,
68 max_chunk_size: usize)
69 -> Result<Message<'a>> {
70 if buffer_threshold.count_ones() != 1 {
71 return Err(Error::InvalidArgument(
72 "buffer_threshold is not a power of two".into()).into());
73 }
74
75 if max_chunk_size.count_ones() != 1 {
76 return Err(Error::InvalidArgument(
77 "max_chunk_size is not a power of two".into()).into());
78 }
79
80 if max_chunk_size > PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE {
81 return Err(Error::InvalidArgument(
82 "max_chunk_size exceeds limit".into()).into());
83 }
84
85 Ok(Message::from(Box::new(PartialBodyFilter {
86 inner: Some(inner.into()),
87 cookie,
88 buffer: Vec::with_capacity(buffer_threshold),
89 buffer_threshold,
90 max_chunk_size,
91 position: 0,
92 })))
93 }
94}
95
96impl<'a, C: 'a> PartialBodyFilter<'a, C> {
97 fn write_out(&mut self, mut other: &[u8], done: bool)
103 -> io::Result<()> {
104 if self.inner.is_none() {
105 return Ok(());
106 }
107 let mut inner = self.inner.as_mut().unwrap();
108
109 if done {
110 let l = self.buffer.len() + other.len();
116 if l > std::u32::MAX as usize {
117 unimplemented!();
118 }
119 BodyLength::Full(l as u32).serialize(inner).map_err(
120 |e| match e.downcast::<io::Error>() {
121 Ok(err) => err,
123 Err(e) => io::Error::new(io::ErrorKind::Other, e),
125 })?;
126
127 inner.write_all(&self.buffer[..])?;
129 crate::vec_truncate(&mut self.buffer, 0);
130 inner.write_all(other)?;
131 } else {
132 while self.buffer.len() + other.len() > self.buffer_threshold {
133
134 let chunk_size_log2 =
136 log2(cmp::min(self.max_chunk_size,
137 self.buffer.len() + other.len())
138 as u32);
139 let chunk_size = (1usize) << chunk_size_log2;
140
141 let size = BodyLength::Partial(chunk_size as u32);
142 let mut size_byte = [0u8];
143 size.serialize(&mut io::Cursor::new(&mut size_byte[..]))
144 .expect("size should be representable");
145 let size_byte = size_byte[0];
146
147 write_byte(&mut inner, size_byte)?;
149
150 let l = cmp::min(self.buffer.len(), chunk_size);
152 inner.write_all(&self.buffer[..l])?;
153 crate::vec_drain_prefix(&mut self.buffer, l);
154
155 if chunk_size > l {
157 inner.write_all(&other[..chunk_size - l])?;
158 other = &other[chunk_size - l..];
159 }
160 }
161
162 self.buffer.extend_from_slice(other);
163 assert!(self.buffer.len() <= self.buffer_threshold);
164 }
165
166 Ok(())
167 }
168}
169
170impl<'a, C: 'a> io::Write for PartialBodyFilter<'a, C> {
171 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
172 if buf.len() >= self.buffer_threshold - self.buffer.len() {
174 self.write_out(buf, false)?;
175 } else {
176 self.buffer.append(buf.to_vec().as_mut());
177 }
178 self.position += buf.len() as u64;
179 Ok(buf.len())
180 }
181
182 fn flush(&mut self) -> io::Result<()> {
185 self.write_out(&b""[..], false)
186 }
187}
188
189impl<'a, C: 'a> fmt::Debug for PartialBodyFilter<'a, C> {
190 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
191 f.debug_struct("PartialBodyFilter")
192 .field("inner", &self.inner)
193 .finish()
194 }
195}
196
197impl<'a, C: 'a> writer::Stackable<'a, C> for PartialBodyFilter<'a, C> {
198 fn into_inner(mut self: Box<Self>) -> Result<Option<writer::BoxStack<'a, C>>> {
199 self.write_out(&b""[..], true)?;
200 Ok(self.inner.take())
201 }
202 fn pop(&mut self) -> Result<Option<writer::BoxStack<'a, C>>> {
203 self.write_out(&b""[..], true)?;
204 Ok(self.inner.take())
205 }
206 fn mount(&mut self, new: writer::BoxStack<'a, C>) {
207 self.inner = Some(new);
208 }
209 fn inner_mut(&mut self) -> Option<&mut (dyn writer::Stackable<'a, C> + Send + Sync)> {
210 if let Some(ref mut i) = self.inner {
211 Some(i)
212 } else {
213 None
214 }
215 }
216 fn inner_ref(&self) -> Option<&(dyn writer::Stackable<'a, C> + Send + Sync)> {
217 if let Some(ref i) = self.inner {
218 Some(i)
219 } else {
220 None
221 }
222 }
223 fn cookie_set(&mut self, cookie: C) -> C {
224 ::std::mem::replace(&mut self.cookie, cookie)
225 }
226 fn cookie_ref(&self) -> &C {
227 &self.cookie
228 }
229 fn cookie_mut(&mut self) -> &mut C {
230 &mut self.cookie
231 }
232 fn position(&self) -> u64 {
233 self.position
234 }
235}
236
#[cfg(test)]
mod test {
    use std::io::Write;
    use super::*;
    use crate::serialize::stream::Message;

    #[test]
    fn basic() {
        let mut buf = Vec::new();
        {
            let message = Message::new(&mut buf);
            let mut pb = PartialBodyFilter::with_limits(
                message, Default::default(),
                /* buffer_threshold: */ 16,
                /* max_chunk_size: */ 16)
                .unwrap();
            pb.write_all(b"0123").unwrap();
            pb.write_all(b"4567").unwrap();
            pb.finalize().unwrap();
        }
        assert_eq!(&buf,
                   &[8, // no chunking
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37]);
    }

    #[test]
    fn no_avoidable_chunking() {
        let mut buf = Vec::new();
        {
            let message = Message::new(&mut buf);
            let mut pb = PartialBodyFilter::with_limits(
                message, Default::default(),
                /* buffer_threshold: */ 4,
                /* max_chunk_size: */ 16)
                .unwrap();
            pb.write_all(b"01234567").unwrap();
            pb.finalize().unwrap();
        }
        assert_eq!(&buf,
                   &[0xe0 + 3, // first chunk
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                     0, // rest
                   ]);
    }

    #[test]
    fn write_exceeding_buffer_threshold() {
        let mut buf = Vec::new();
        {
            let message = Message::new(&mut buf);
            let mut pb = PartialBodyFilter::with_limits(
                message, Default::default(),
                /* buffer_threshold: */ 8,
                /* max_chunk_size: */ 16)
                .unwrap();
            pb.write_all(b"012345670123456701234567").unwrap();
            pb.finalize().unwrap();
        }
        assert_eq!(&buf,
                   &[0xe0 + 4, // first chunk
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                     8, // rest
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37]);
    }

    #[test]
    fn chunk_spans_buffer_and_new_data() {
        let mut buf = Vec::new();
        {
            let message = Message::new(&mut buf);
            let mut pb = PartialBodyFilter::with_limits(
                message, Default::default(),
                /* buffer_threshold: */ 8,
                /* max_chunk_size: */ 16)
                .unwrap();
            // Stays below the threshold, so it is only buffered.
            pb.write_all(b"012345").unwrap();
            // Pushes the pending data past the threshold: an 8-byte chunk
            // is made of the 6 buffered bytes plus 2 new ones.
            pb.write_all(b"abcdef").unwrap();
            pb.finalize().unwrap();
        }
        assert_eq!(&buf,
                   &[0xe0 + 3, // partial chunk of 2^3 bytes
                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x61, 0x62,
                     4, // final chunk
                     0x63, 0x64, 0x65, 0x66]);
    }

    #[test]
    fn max_chunk_size_below_buffer_threshold() {
        let mut buf = Vec::new();
        {
            let message = Message::new(&mut buf);
            let mut pb = PartialBodyFilter::with_limits(
                message, Default::default(),
                /* buffer_threshold: */ 16,
                /* max_chunk_size: */ 4)
                .unwrap();
            pb.write_all(b"0123456789abcdefghij").unwrap();
            pb.finalize().unwrap();
        }
        assert_eq!(&buf,
                   &[0xe0 + 2, // partial chunk capped at 2^2 bytes
                     0x30, 0x31, 0x32, 0x33,
                     16, // final chunk
                     0x34, 0x35, 0x36, 0x37, 0x38, 0x39,
                     0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
                     0x69, 0x6a]);
    }

    /// Returns whether `with_limits` accepts the given limits.
    fn limits_accepted(buffer_threshold: usize, max_chunk_size: usize)
                       -> bool {
        let mut buf: Vec<u8> = Vec::new();
        let result = PartialBodyFilter::with_limits(
            Message::new(&mut buf), Default::default(),
            buffer_threshold, max_chunk_size);
        result.is_ok()
    }

    #[test]
    fn rejects_invalid_limits() {
        // Both limits must be powers of two.
        assert!(! limits_accepted(0, 16));
        assert!(! limits_accepted(12, 16));
        assert!(! limits_accepted(16, 0));
        assert!(! limits_accepted(16, 24));
        // Chunks may not exceed 2^30 bytes.
        assert!(! limits_accepted(16, 1 << 31));
        assert!(limits_accepted(16, 1 << 30));
        assert!(limits_accepted(1, 1));
    }
}