1mod crypt_writer;
27use crypt_writer::CryptWriter;
28use futures::prelude::*;
29use log::trace;
30use pin_project::pin_project;
31use rand::RngCore;
32use salsa20::{
33 cipher::{NewStreamCipher, SyncStreamCipher},
34 Salsa20, XSalsa20,
35};
36use sha3::{digest::ExtendableOutput, Shake128};
37use std::{
38 error,
39 fmt::{self, Write},
40 io,
41 io::Error as IoError,
42 num::ParseIntError,
43 pin::Pin,
44 str::FromStr,
45 task::{Context, Poll},
46};
47
/// Length in bytes of a pre-shared key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the random nonce each side sends during the handshake (192 bits).
const NONCE_SIZE: usize = 192 / 8;
/// Capacity passed to `CryptWriter::with_capacity` for the outgoing side of a connection.
const WRITE_BUFFER_SIZE: usize = 1024;
/// Length in bytes of a key fingerprint (128 bits).
const FINGERPRINT_SIZE: usize = 128 / 8;
52
53#[derive(Copy, Clone, PartialEq, Eq)]
55pub struct PreSharedKey([u8; KEY_SIZE]);
56
57impl PreSharedKey {
58 pub fn new(data: [u8; KEY_SIZE]) -> Self {
60 Self(data)
61 }
62
63 pub fn fingerprint(&self) -> Fingerprint {
69 use std::io::{Read, Write};
70 let mut enc = [0u8; 64];
71 let nonce: [u8; 8] = *b"finprint";
72 let mut out = [0u8; 16];
73 let mut cipher = Salsa20::new(&self.0.into(), &nonce.into());
74 cipher.apply_keystream(&mut enc);
75 let mut hasher = Shake128::default();
76 hasher.write_all(&enc).expect("shake128 failed");
77 hasher.finalize_xof().read_exact(&mut out).expect("shake128 failed");
78 Fingerprint(out)
79 }
80}
81
82fn parse_hex_key(s: &str) -> Result<[u8; KEY_SIZE], KeyParseError> {
83 if s.len() == KEY_SIZE * 2 {
84 let mut r = [0u8; KEY_SIZE];
85 for i in 0..KEY_SIZE {
86 r[i] = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
87 .map_err(KeyParseError::InvalidKeyChar)?;
88 }
89 Ok(r)
90 } else {
91 Err(KeyParseError::InvalidKeyLength)
92 }
93}
94
/// Renders `bytes` as lowercase hex, two characters per byte.
fn to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            write!(acc, "{:02x}", byte).expect("Can't fail on writing to string");
            acc
        })
}
104
105impl FromStr for PreSharedKey {
109 type Err = KeyParseError;
110
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 if let [keytype, encoding, key] = *s.lines().take(3).collect::<Vec<_>>().as_slice() {
113 if keytype != "/key/swarm/psk/1.0.0/" {
114 return Err(KeyParseError::InvalidKeyType);
115 }
116 if encoding != "/base16/" {
117 return Err(KeyParseError::InvalidKeyEncoding);
118 }
119 parse_hex_key(key.trim_end()).map(PreSharedKey)
120 } else {
121 Err(KeyParseError::InvalidKeyFile)
122 }
123 }
124}
125
126impl fmt::Debug for PreSharedKey {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_tuple("PreSharedKey")
129 .field(&to_hex(&self.0))
130 .finish()
131 }
132}
133
134impl fmt::Display for PreSharedKey {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 writeln!(f, "/key/swarm/psk/1.0.0/")?;
138 writeln!(f, "/base16/")?;
139 writeln!(f, "{}", to_hex(&self.0))
140 }
141}
142
143#[derive(Copy, Clone, PartialEq, Eq)]
145pub struct Fingerprint([u8; FINGERPRINT_SIZE]);
146
147impl fmt::Display for Fingerprint {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150 write!(f, "{}", to_hex(&self.0))
151 }
152}
153
/// Reasons why parsing a `PreSharedKey` from its key-file text can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum KeyParseError {
    /// Fewer than three lines (key type, encoding, key) were present.
    InvalidKeyFile,
    /// The first line is not the supported key type.
    InvalidKeyType,
    /// The second line is not the supported encoding.
    InvalidKeyEncoding,
    /// The key line does not have `KEY_SIZE * 2` characters.
    InvalidKeyLength,
    /// The key line contains something that is not a hexadecimal digit.
    InvalidKeyChar(ParseIntError),
}
168
169impl fmt::Display for KeyParseError {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 write!(f, "{:?}", self)
172 }
173}
174
175impl error::Error for KeyParseError {
176 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
177 match *self {
178 KeyParseError::InvalidKeyChar(ref err) => Some(err),
179 _ => None,
180 }
181 }
182}
183
184#[derive(Debug, Copy, Clone)]
186pub struct PnetConfig {
187 key: PreSharedKey,
189}
190impl PnetConfig {
191 pub fn new(key: PreSharedKey) -> Self {
192 Self { key }
193 }
194
195 pub async fn handshake<TSocket>(
200 self,
201 mut socket: TSocket,
202 ) -> Result<PnetOutput<TSocket>, PnetError>
203 where
204 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
205 {
206 trace!("exchanging nonces");
207 let mut local_nonce = [0u8; NONCE_SIZE];
208 let mut remote_nonce = [0u8; NONCE_SIZE];
209 rand::thread_rng().fill_bytes(&mut local_nonce);
210 socket
211 .write_all(&local_nonce)
212 .await
213 .map_err(PnetError::HandshakeError)?;
214 socket
215 .read_exact(&mut remote_nonce)
216 .await
217 .map_err(PnetError::HandshakeError)?;
218 trace!("setting up ciphers");
219 let write_cipher = XSalsa20::new(&self.key.0.into(), &local_nonce.into());
220 let read_cipher = XSalsa20::new(&self.key.0.into(), &remote_nonce.into());
221 Ok(PnetOutput::new(socket, write_cipher, read_cipher))
222 }
223}
224
225#[pin_project]
228pub struct PnetOutput<S> {
229 #[pin]
230 inner: CryptWriter<S>,
231 read_cipher: XSalsa20,
232}
233
234impl<S: AsyncRead + AsyncWrite> PnetOutput<S> {
235 fn new(inner: S, write_cipher: XSalsa20, read_cipher: XSalsa20) -> Self {
236 Self {
237 inner: CryptWriter::with_capacity(WRITE_BUFFER_SIZE, inner, write_cipher),
238 read_cipher,
239 }
240 }
241}
242
243impl<S: AsyncRead + AsyncWrite> AsyncRead for PnetOutput<S> {
244 fn poll_read(
245 self: Pin<&mut Self>,
246 cx: &mut Context<'_>,
247 buf: &mut [u8],
248 ) -> Poll<Result<usize, io::Error>> {
249 let this = self.project();
250 let result = this.inner.get_pin_mut().poll_read(cx, buf);
251 if let Poll::Ready(Ok(size)) = &result {
252 trace!("read {} bytes", size);
253 this.read_cipher.apply_keystream(&mut buf[..*size]);
254 trace!("decrypted {} bytes", size);
255 }
256 result
257 }
258}
259
260impl<S: AsyncRead + AsyncWrite> AsyncWrite for PnetOutput<S> {
261 fn poll_write(
262 self: Pin<&mut Self>,
263 cx: &mut Context<'_>,
264 buf: &[u8],
265 ) -> Poll<Result<usize, io::Error>> {
266 self.project().inner.poll_write(cx, buf)
267 }
268
269 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
270 self.project().inner.poll_flush(cx)
271 }
272
273 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
274 self.project().inner.poll_close(cx)
275 }
276}
277
/// Errors produced while setting up or using a private-network connection.
#[derive(Debug)]
pub enum PnetError {
    /// An I/O failure while exchanging nonces in `PnetConfig::handshake`.
    HandshakeError(IoError),
    /// An I/O failure converted through `From<IoError>`.
    IoError(IoError),
}
286
287impl From<IoError> for PnetError {
288 #[inline]
289 fn from(err: IoError) -> PnetError {
290 PnetError::IoError(err)
291 }
292}
293
294impl error::Error for PnetError {
295 fn cause(&self) -> Option<&dyn error::Error> {
296 match *self {
297 PnetError::HandshakeError(ref err) => Some(err),
298 PnetError::IoError(ref err) => Some(err),
299 }
300 }
301}
302
303impl fmt::Display for PnetError {
304 #[inline]
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
306 match self {
307 PnetError::HandshakeError(e) => write!(f, "Handshake error: {}", e),
308 PnetError::IoError(e) => write!(f, "I/O error: {}", e),
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    impl Arbitrary for PreSharedKey {
        fn arbitrary<G: Gen>(g: &mut G) -> PreSharedKey {
            let mut bytes = [0u8; KEY_SIZE];
            g.fill_bytes(&mut bytes);
            PreSharedKey(bytes)
        }
    }

    /// Writing a key with `Display` and parsing it back yields the same key.
    #[test]
    fn psk_tostring_parse() {
        fn roundtrips(key: PreSharedKey) -> bool {
            key.to_string().parse::<PreSharedKey>() == Ok(key)
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(roundtrips as fn(PreSharedKey) -> bool);
    }

    /// Each malformed key file is rejected with the matching error.
    #[test]
    fn psk_parse_failure() {
        use KeyParseError::*;
        let cases = [
            ("", InvalidKeyFile),
            ("a\nb\nc", InvalidKeyType),
            ("/key/swarm/psk/1.0.0/\nx\ny", InvalidKeyEncoding),
            ("/key/swarm/psk/1.0.0/\n/base16/\ny", InvalidKeyLength),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(&input.parse::<PreSharedKey>().unwrap_err(), expected);
        }
    }

    /// The fingerprint of a fixed key matches a known value.
    #[test]
    fn fingerprint() {
        let key = "/key/swarm/psk/1.0.0/\n/base16/\n6189c5cf0b87fb800c1a9feeda73c6ab5e998db48fb9e6a978575c770ceef683"
            .parse::<PreSharedKey>()
            .unwrap();
        assert_eq!(
            key.fingerprint().to_string(),
            "45fc986bbc9388a11d939df26f730f0c"
        );
    }
}
369}