ssb_boxstream/
write.rs

1use crate::msg::*;
2use crate::NonceGen;
3use core::cmp::min;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use futures_core::ready;
7use futures_io::{AsyncWrite, Error};
8use ssb_crypto::secretbox::{Key, Nonce};
9
10pub const MAX_BOX_SIZE: usize = 4096;
11
12pub(crate) fn seal(mut body: &mut [u8], key: &Key, noncegen: &mut NonceGen) -> Head {
13    let head_nonce = noncegen.next();
14    let body_nonce = noncegen.next();
15
16    let body_hmac = key.seal(&mut body, &body_nonce);
17    HeadPayload::new(body.len() as u16, body_hmac).seal(&key, head_nonce)
18}
19
20pub struct BoxWriter<W, B> {
21    inner: W,
22    buffer: B,
23    state: State,
24    key: Key,
25    nonces: NonceGen,
26}
27
28impl<W, B> BoxWriter<W, B> {
29    pub fn with_buffer(inner: W, key: Key, nonce: Nonce, buffer: B) -> BoxWriter<W, B> {
30        BoxWriter {
31            inner,
32            buffer,
33            state: State::Buffering { pos: 0 },
34            key,
35            nonces: NonceGen::with_starting_nonce(nonce),
36        }
37    }
38
39    pub fn is_closed(&self) -> bool {
40        matches!(self.state, State::Closed)
41    }
42
43    pub fn into_inner(self) -> W {
44        self.inner
45    }
46}
47
48impl<W> BoxWriter<W, Vec<u8>> {
49    pub fn new(w: W, key: Key, nonce: Nonce) -> BoxWriter<W, Vec<u8>> {
50        BoxWriter::with_buffer(w, key, nonce, vec![0; 4096])
51    }
52}
53
54enum State {
55    Buffering {
56        pos: usize,
57    },
58    SendingHead {
59        head: Head,
60        pos: usize,
61        body_size: usize,
62    },
63    SendingBody {
64        body_size: usize,
65        pos: usize,
66    },
67    SendingGoodbye {
68        head: Head,
69        pos: usize,
70    },
71    Closed,
72}
73
74impl<W, B> AsyncWrite for BoxWriter<W, B>
75where
76    W: AsyncWrite + Unpin + 'static,
77    B: AsMut<[u8]> + Unpin,
78{
79    fn poll_write(
80        self: Pin<&mut Self>,
81        cx: &mut Context,
82        mut to_write: &[u8],
83    ) -> Poll<Result<usize, Error>> {
84        let mut this = self.get_mut();
85        let mut wrote_bytes = 0;
86
87        loop {
88            match this.state {
89                State::Buffering { pos } => {
90                    let buffer = this.buffer.as_mut();
91                    let n = min(buffer.len() - pos, to_write.len());
92
93                    let (b, rest) = to_write.split_at(n);
94                    buffer[pos..pos + n].copy_from_slice(b);
95
96                    wrote_bytes += n;
97                    to_write = rest;
98
99                    if pos + n == buffer.len() {
100                        let head = seal(buffer, &this.key, &mut this.nonces);
101                        this.state = State::SendingHead {
102                            head,
103                            pos: 0,
104                            body_size: buffer.len(),
105                        };
106                    } else {
107                        this.state = State::Buffering { pos: pos + n };
108                        return Poll::Ready(Ok(wrote_bytes));
109                    }
110                }
111
112                State::SendingHead {
113                    head,
114                    pos,
115                    body_size,
116                } => {
117                    let hb = head.as_bytes();
118                    let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &hb[pos..]))?;
119                    if pos + n == hb.len() {
120                        this.state = State::SendingBody { body_size, pos: 0 };
121                    } else {
122                        this.state = State::SendingHead {
123                            head,
124                            pos: pos + n,
125                            body_size,
126                        };
127                        return Poll::Pending;
128                    }
129                }
130
131                State::SendingBody { body_size, pos } => {
132                    let n = ready!(Pin::new(&mut this.inner)
133                        .poll_write(cx, &this.buffer.as_mut()[pos..body_size]))?;
134                    if pos + n == body_size {
135                        this.state = State::Buffering { pos: 0 };
136                    } else {
137                        this.state = State::SendingBody {
138                            body_size,
139                            pos: pos + n,
140                        };
141                        return Poll::Pending;
142                    }
143                }
144
145                State::SendingGoodbye { .. } => panic!(), // ??
146                State::Closed => return Poll::Ready(Ok(0)),
147            }
148        }
149    }
150
151    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
152        let mut this = self.get_mut();
153        match this.state {
154            State::Buffering { pos } => {
155                if pos == 0 {
156                    Pin::new(&mut this.inner).poll_flush(cx)
157                } else {
158                    let mut body = &mut this.buffer.as_mut()[..pos];
159                    let head = seal(&mut body, &this.key, &mut this.nonces);
160                    this.state = State::SendingHead {
161                        head,
162                        pos: 0,
163                        body_size: pos,
164                    };
165                    Pin::new(this).poll_flush(cx)
166                }
167            }
168
169            State::SendingHead {
170                head,
171                pos,
172                body_size,
173            } => {
174                let bytes = head.as_bytes();
175
176                let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &bytes[pos..]))?;
177                if pos + n == bytes.len() {
178                    this.state = State::SendingBody { body_size, pos: 0 };
179                    Pin::new(this).poll_flush(cx)
180                } else {
181                    this.state = State::SendingHead {
182                        head,
183                        pos: pos + n,
184                        body_size,
185                    };
186                    Poll::Pending
187                }
188            }
189
190            State::SendingBody { body_size, pos } => {
191                let n =
192                    ready!(Pin::new(&mut this.inner)
193                        .poll_write(cx, &this.buffer.as_mut()[pos..body_size]))?;
194                if pos + n == body_size {
195                    this.state = State::Buffering { pos: 0 };
196                    Pin::new(&mut this.inner).poll_flush(cx)
197                } else {
198                    this.state = State::SendingBody {
199                        body_size,
200                        pos: pos + n,
201                    };
202                    Poll::Pending
203                }
204            }
205            State::SendingGoodbye { .. } => panic!(),
206            State::Closed => Poll::Ready(Ok(())),
207        }
208    }
209
210    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
211        let mut this = self.get_mut();
212        match this.state {
213            State::SendingGoodbye { head, pos } => {
214                let bytes = head.as_bytes();
215
216                let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &bytes[pos..]))?;
217                if pos + n == bytes.len() {
218                    this.state = State::Closed;
219                    Pin::new(&mut this.inner).poll_close(cx)
220                } else {
221                    this.state = State::SendingGoodbye { head, pos: pos + n };
222                    Poll::Pending
223                }
224            }
225
226            _ => {
227                ready!(Pin::new(&mut this).poll_flush(cx))?;
228                let head = HeadPayload::goodbye().seal(&this.key, this.nonces.next());
229                this.state = State::SendingGoodbye { head, pos: 0 };
230                Pin::new(&mut this).poll_close(cx)
231            }
232        }
233    }
234}