use crate::msg::*;
use crate::NonceGen;
use core::cmp::min;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_core::ready;
use futures_io::{AsyncWrite, Error, ErrorKind};
use ssb_crypto::secretbox::{Key, Nonce};
9
/// Upper bound, in bytes, on the plaintext body of one box (4 KiB); also the
/// size of the default buffer allocated by `BoxWriter::new`.
/// NOTE(review): presumably this mirrors the box-stream protocol limit — confirm.
pub const MAX_BOX_SIZE: usize = 4 * 1024;
11
12pub(crate) fn seal(mut body: &mut [u8], key: &Key, noncegen: &mut NonceGen) -> Head {
13 let head_nonce = noncegen.next();
14 let body_nonce = noncegen.next();
15
16 let body_hmac = key.seal(&mut body, &body_nonce);
17 HeadPayload::new(body.len() as u16, body_hmac).seal(&key, head_nonce)
18}
19
/// An [`AsyncWrite`] adapter that encrypts everything written to it into sealed
/// boxes before handing them to `inner`.
///
/// Each box goes out as a sealed header (carrying the body length and HMAC)
/// followed by the body, which is sealed in place inside `buffer`.
pub struct BoxWriter<W, B> {
    /// Underlying writer that receives the sealed headers and bodies.
    inner: W,
    /// Staging area: plaintext is copied in, sealed in place, then written out.
    buffer: B,
    /// Position in the buffer / send-header / send-body / goodbye cycle.
    state: State,
    /// Secret-box key used to seal both box headers and box bodies.
    key: Key,
    /// Nonce source: two nonces per box (header, body) and one for the goodbye header.
    nonces: NonceGen,
}
27
28impl<W, B> BoxWriter<W, B> {
29 pub fn with_buffer(inner: W, key: Key, nonce: Nonce, buffer: B) -> BoxWriter<W, B> {
30 BoxWriter {
31 inner,
32 buffer,
33 state: State::Buffering { pos: 0 },
34 key,
35 nonces: NonceGen::with_starting_nonce(nonce),
36 }
37 }
38
39 pub fn is_closed(&self) -> bool {
40 matches!(self.state, State::Closed)
41 }
42
43 pub fn into_inner(self) -> W {
44 self.inner
45 }
46}
47
48impl<W> BoxWriter<W, Vec<u8>> {
49 pub fn new(w: W, key: Key, nonce: Nonce) -> BoxWriter<W, Vec<u8>> {
50 BoxWriter::with_buffer(w, key, nonce, vec![0; 4096])
51 }
52}
53
/// Progress of a [`BoxWriter`] through buffering, sending, and closing.
enum State {
    /// Accumulating plaintext; the first `pos` bytes of the buffer are filled.
    Buffering {
        pos: usize,
    },
    /// Writing a box's sealed header: `pos` header bytes are already written, and
    /// the sealed body occupies the first `body_size` bytes of the buffer.
    SendingHead {
        head: Head,
        pos: usize,
        body_size: usize,
    },
    /// Writing the sealed body `buffer[..body_size]`; `pos` bytes are already written.
    SendingBody {
        body_size: usize,
        pos: usize,
    },
    /// Writing the sealed goodbye header during close; `pos` bytes are already written.
    SendingGoodbye {
        head: Head,
        pos: usize,
    },
    /// The goodbye header has been fully written.
    Closed,
}
73
74impl<W, B> AsyncWrite for BoxWriter<W, B>
75where
76 W: AsyncWrite + Unpin + 'static,
77 B: AsMut<[u8]> + Unpin,
78{
79 fn poll_write(
80 self: Pin<&mut Self>,
81 cx: &mut Context,
82 mut to_write: &[u8],
83 ) -> Poll<Result<usize, Error>> {
84 let mut this = self.get_mut();
85 let mut wrote_bytes = 0;
86
87 loop {
88 match this.state {
89 State::Buffering { pos } => {
90 let buffer = this.buffer.as_mut();
91 let n = min(buffer.len() - pos, to_write.len());
92
93 let (b, rest) = to_write.split_at(n);
94 buffer[pos..pos + n].copy_from_slice(b);
95
96 wrote_bytes += n;
97 to_write = rest;
98
99 if pos + n == buffer.len() {
100 let head = seal(buffer, &this.key, &mut this.nonces);
101 this.state = State::SendingHead {
102 head,
103 pos: 0,
104 body_size: buffer.len(),
105 };
106 } else {
107 this.state = State::Buffering { pos: pos + n };
108 return Poll::Ready(Ok(wrote_bytes));
109 }
110 }
111
112 State::SendingHead {
113 head,
114 pos,
115 body_size,
116 } => {
117 let hb = head.as_bytes();
118 let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &hb[pos..]))?;
119 if pos + n == hb.len() {
120 this.state = State::SendingBody { body_size, pos: 0 };
121 } else {
122 this.state = State::SendingHead {
123 head,
124 pos: pos + n,
125 body_size,
126 };
127 return Poll::Pending;
128 }
129 }
130
131 State::SendingBody { body_size, pos } => {
132 let n = ready!(Pin::new(&mut this.inner)
133 .poll_write(cx, &this.buffer.as_mut()[pos..body_size]))?;
134 if pos + n == body_size {
135 this.state = State::Buffering { pos: 0 };
136 } else {
137 this.state = State::SendingBody {
138 body_size,
139 pos: pos + n,
140 };
141 return Poll::Pending;
142 }
143 }
144
145 State::SendingGoodbye { .. } => panic!(), State::Closed => return Poll::Ready(Ok(0)),
147 }
148 }
149 }
150
151 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
152 let mut this = self.get_mut();
153 match this.state {
154 State::Buffering { pos } => {
155 if pos == 0 {
156 Pin::new(&mut this.inner).poll_flush(cx)
157 } else {
158 let mut body = &mut this.buffer.as_mut()[..pos];
159 let head = seal(&mut body, &this.key, &mut this.nonces);
160 this.state = State::SendingHead {
161 head,
162 pos: 0,
163 body_size: pos,
164 };
165 Pin::new(this).poll_flush(cx)
166 }
167 }
168
169 State::SendingHead {
170 head,
171 pos,
172 body_size,
173 } => {
174 let bytes = head.as_bytes();
175
176 let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &bytes[pos..]))?;
177 if pos + n == bytes.len() {
178 this.state = State::SendingBody { body_size, pos: 0 };
179 Pin::new(this).poll_flush(cx)
180 } else {
181 this.state = State::SendingHead {
182 head,
183 pos: pos + n,
184 body_size,
185 };
186 Poll::Pending
187 }
188 }
189
190 State::SendingBody { body_size, pos } => {
191 let n =
192 ready!(Pin::new(&mut this.inner)
193 .poll_write(cx, &this.buffer.as_mut()[pos..body_size]))?;
194 if pos + n == body_size {
195 this.state = State::Buffering { pos: 0 };
196 Pin::new(&mut this.inner).poll_flush(cx)
197 } else {
198 this.state = State::SendingBody {
199 body_size,
200 pos: pos + n,
201 };
202 Poll::Pending
203 }
204 }
205 State::SendingGoodbye { .. } => panic!(),
206 State::Closed => Poll::Ready(Ok(())),
207 }
208 }
209
210 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
211 let mut this = self.get_mut();
212 match this.state {
213 State::SendingGoodbye { head, pos } => {
214 let bytes = head.as_bytes();
215
216 let n = ready!(Pin::new(&mut this.inner).poll_write(cx, &bytes[pos..]))?;
217 if pos + n == bytes.len() {
218 this.state = State::Closed;
219 Pin::new(&mut this.inner).poll_close(cx)
220 } else {
221 this.state = State::SendingGoodbye { head, pos: pos + n };
222 Poll::Pending
223 }
224 }
225
226 _ => {
227 ready!(Pin::new(&mut this).poll_flush(cx))?;
228 let head = HeadPayload::goodbye().seal(&this.key, this.nonces.next());
229 this.state = State::SendingGoodbye { head, pos: 0 };
230 Pin::new(&mut this).poll_close(cx)
231 }
232 }
233 }
234}