// sequoia_openpgp/serialize/stream/partial_body.rs
//! Encodes a byte stream using OpenPGP's partial body encoding.

use std::fmt;
use std::io;
use std::cmp;

use crate::Error;
use crate::Result;
use crate::packet::header::BodyLength;
use crate::serialize::{
    log2,
    stream::{
        writer,
        Message,
        Cookie,
    },
    write_byte,
    Marshal,
};
20
21pub struct PartialBodyFilter<'a, C: 'a> {
22    // The underlying writer.
23    //
24    // XXX: Opportunity for optimization.  Previously, this writer
25    // implemented `Drop`, so we could not move the inner writer out
26    // of this writer.  We therefore wrapped it with `Option` so that
27    // we can `take()` it.  This writer no longer implements Drop, so
28    // we could avoid the Option here.
29    inner: Option<writer::BoxStack<'a, C>>,
30
31    // The cookie.
32    cookie: C,
33
34    // The buffer.
35    buffer: Vec<u8>,
36
37    // The amount to buffer before flushing.
38    buffer_threshold: usize,
39
40    // The maximum size of a partial body chunk.  The standard allows
41    // for chunks up to 1 GB in size.
42    max_chunk_size: usize,
43
44    // The number of bytes written to this filter.
45    position: u64,
46}
47assert_send_and_sync!(PartialBodyFilter<'_, C> where C);
48
49const PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE : usize = 1 << 30;
50
51// The amount to buffer before flushing.  If this is small, we get
52// lots of small partial body packets, which is annoying.
53const PARTIAL_BODY_FILTER_BUFFER_THRESHOLD : usize = 4 * 1024 * 1024;
54
55impl<'a> PartialBodyFilter<'a, Cookie> {
56    /// Returns a new partial body encoder.
57    pub fn new(inner: Message<'a>, cookie: Cookie)
58               -> Message<'a> {
59        Self::with_limits(inner, cookie,
60                          PARTIAL_BODY_FILTER_BUFFER_THRESHOLD,
61                          PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE)
62            .expect("safe limits")
63    }
64
65    /// Returns a new partial body encoder with the given limits.
66    pub fn with_limits(inner: Message<'a>, cookie: Cookie,
67                       buffer_threshold: usize,
68                       max_chunk_size: usize)
69                       -> Result<Message<'a>> {
70        if buffer_threshold.count_ones() != 1 {
71            return Err(Error::InvalidArgument(
72                "buffer_threshold is not a power of two".into()).into());
73        }
74
75        if max_chunk_size.count_ones() != 1 {
76            return Err(Error::InvalidArgument(
77                "max_chunk_size is not a power of two".into()).into());
78        }
79
80        if max_chunk_size > PARTIAL_BODY_FILTER_MAX_CHUNK_SIZE {
81            return Err(Error::InvalidArgument(
82                "max_chunk_size exceeds limit".into()).into());
83        }
84
85        Ok(Message::from(Box::new(PartialBodyFilter {
86            inner: Some(inner.into()),
87            cookie,
88            buffer: Vec::with_capacity(buffer_threshold),
89            buffer_threshold,
90            max_chunk_size,
91            position: 0,
92        })))
93    }
94}
95
96impl<'a, C: 'a> PartialBodyFilter<'a, C> {
97    // Writes out any full chunks between `self.buffer` and `other`.
98    // Any extra data is buffered.
99    //
100    // If `done` is set, then flushes any data, and writes the end of
101    // the partial body encoding.
102    fn write_out(&mut self, mut other: &[u8], done: bool)
103                 -> io::Result<()> {
104        if self.inner.is_none() {
105            return Ok(());
106        }
107        let mut inner = self.inner.as_mut().unwrap();
108
109        if done {
110            // We're done.  The last header MUST be a non-partial body
111            // header.  We have to write it even if it is 0 bytes
112            // long.
113
114            // Write the header.
115            let l = self.buffer.len() + other.len();
116            if l > std::u32::MAX as usize {
117                unimplemented!();
118            }
119            BodyLength::Full(l as u32).serialize(inner).map_err(
120                |e| match e.downcast::<io::Error>() {
121                        // An io::Error.  Pass as-is.
122                        Ok(err) => err,
123                        // A failure.  Wrap it.
124                        Err(e) => io::Error::new(io::ErrorKind::Other, e),
125                    })?;
126
127            // Write the body.
128            inner.write_all(&self.buffer[..])?;
129            crate::vec_truncate(&mut self.buffer, 0);
130            inner.write_all(other)?;
131        } else {
132            while self.buffer.len() + other.len() > self.buffer_threshold {
133
134                // Write a partial body length header.
135                let chunk_size_log2 =
136                    log2(cmp::min(self.max_chunk_size,
137                                  self.buffer.len() + other.len())
138                         as u32);
139                let chunk_size = (1usize) << chunk_size_log2;
140
141                let size = BodyLength::Partial(chunk_size as u32);
142                let mut size_byte = [0u8];
143                size.serialize(&mut io::Cursor::new(&mut size_byte[..]))
144                    .expect("size should be representable");
145                let size_byte = size_byte[0];
146
147                // Write out the chunk...
148                write_byte(&mut inner, size_byte)?;
149
150                // ... from our buffer first...
151                let l = cmp::min(self.buffer.len(), chunk_size);
152                inner.write_all(&self.buffer[..l])?;
153                crate::vec_drain_prefix(&mut self.buffer, l);
154
155                // ... then from other.
156                if chunk_size > l {
157                    inner.write_all(&other[..chunk_size - l])?;
158                    other = &other[chunk_size - l..];
159                }
160            }
161
162            self.buffer.extend_from_slice(other);
163            assert!(self.buffer.len() <= self.buffer_threshold);
164        }
165
166        Ok(())
167    }
168}
169
170impl<'a, C: 'a> io::Write for PartialBodyFilter<'a, C> {
171    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
172        // If we can write out a chunk, avoid an extra copy.
173        if buf.len() >= self.buffer_threshold - self.buffer.len() {
174            self.write_out(buf, false)?;
175        } else {
176            self.buffer.append(buf.to_vec().as_mut());
177        }
178        self.position += buf.len() as u64;
179        Ok(buf.len())
180    }
181
182    // XXX: The API says that `flush` is supposed to flush any
183    // internal buffers to disk.  We don't do that.
184    fn flush(&mut self) -> io::Result<()> {
185        self.write_out(&b""[..], false)
186    }
187}
188
189impl<'a, C: 'a> fmt::Debug for PartialBodyFilter<'a, C> {
190    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
191        f.debug_struct("PartialBodyFilter")
192            .field("inner", &self.inner)
193            .finish()
194    }
195}
196
197impl<'a, C: 'a> writer::Stackable<'a, C> for PartialBodyFilter<'a, C> {
198    fn into_inner(mut self: Box<Self>) -> Result<Option<writer::BoxStack<'a, C>>> {
199        self.write_out(&b""[..], true)?;
200        Ok(self.inner.take())
201    }
202    fn pop(&mut self) -> Result<Option<writer::BoxStack<'a, C>>> {
203        self.write_out(&b""[..], true)?;
204        Ok(self.inner.take())
205    }
206    fn mount(&mut self, new: writer::BoxStack<'a, C>) {
207        self.inner = Some(new);
208    }
209    fn inner_mut(&mut self) -> Option<&mut (dyn writer::Stackable<'a, C> + Send + Sync)> {
210        if let Some(ref mut i) = self.inner {
211            Some(i)
212        } else {
213            None
214        }
215    }
216    fn inner_ref(&self) -> Option<&(dyn writer::Stackable<'a, C> + Send + Sync)> {
217        if let Some(ref i) = self.inner {
218            Some(i)
219        } else {
220            None
221        }
222    }
223    fn cookie_set(&mut self, cookie: C) -> C {
224        ::std::mem::replace(&mut self.cookie, cookie)
225    }
226    fn cookie_ref(&self) -> &C {
227        &self.cookie
228    }
229    fn cookie_mut(&mut self) -> &mut C {
230        &mut self.cookie
231    }
232    fn position(&self) -> u64 {
233        self.position
234    }
235}
236
237#[cfg(test)]
238mod test {
239    use std::io::Write;
240    use super::*;
241    use crate::serialize::stream::Message;
242
243    #[test]
244    fn basic() {
245        let mut buf = Vec::new();
246        {
247            let message = Message::new(&mut buf);
248            let mut pb = PartialBodyFilter::with_limits(
249                message, Default::default(),
250                /* buffer_threshold: */ 16,
251                /*   max_chunk_size: */ 16)
252                .unwrap();
253            pb.write_all(b"0123").unwrap();
254            pb.write_all(b"4567").unwrap();
255            pb.finalize().unwrap();
256        }
257        assert_eq!(&buf,
258                   &[8, // no chunking
259                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37]);
260    }
261
262    #[test]
263    fn no_avoidable_chunking() {
264        let mut buf = Vec::new();
265        {
266            let message = Message::new(&mut buf);
267            let mut pb = PartialBodyFilter::with_limits(
268                message, Default::default(),
269                /* buffer_threshold: */ 4,
270                /*   max_chunk_size: */ 16)
271                .unwrap();
272            pb.write_all(b"01234567").unwrap();
273            pb.finalize().unwrap();
274        }
275        assert_eq!(&buf,
276                   &[0xe0 + 3, // first chunk
277                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
278                     0, // rest
279                   ]);
280    }
281
282    #[test]
283    fn write_exceeding_buffer_threshold() {
284        let mut buf = Vec::new();
285        {
286            let message = Message::new(&mut buf);
287            let mut pb = PartialBodyFilter::with_limits(
288                message, Default::default(),
289                /* buffer_threshold: */ 8,
290                /*   max_chunk_size: */ 16)
291                .unwrap();
292            pb.write_all(b"012345670123456701234567").unwrap();
293            pb.finalize().unwrap();
294        }
295        assert_eq!(&buf,
296                   &[0xe0 + 4, // first chunk
297                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
298                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
299                     8, // rest
300                     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37]);
301    }
302}