1use crate::error::ViiperError;
4use chacha20poly1305::{
5 aead::{Aead, KeyInit},
6 ChaCha20Poly1305, Nonce,
7};
8use hmac::{Hmac, Mac};
9use pbkdf2::pbkdf2_hmac;
10use rand::RngCore;
11use sha2::{Digest, Sha256};
12use std::io::{Read, Write};
13use std::net::TcpStream;
14
15#[cfg(feature = "async")]
16use std::pin::Pin;
17#[cfg(feature = "async")]
18use std::task::{Context, Poll};
19#[cfg(feature = "async")]
20use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
21#[cfg(feature = "async")]
22use tokio::net::TcpStream as AsyncTcpStream;
23
/// Prefix of the client handshake message: ASCII "eVI1" followed by a NUL byte.
const HANDSHAKE_MAGIC: &[u8] = b"eVI1\0";
/// Length in bytes of both the client nonce and the server nonce exchanged in the handshake.
const NONCE_SIZE: usize = 32;

/// Context label fed into the client's handshake HMAC ahead of the client nonce.
const AUTH_CONTEXT: &[u8] = b"VIIPER-Auth-v1";
/// Context label appended last when hashing the per-connection session key.
const SESSION_CONTEXT: &[u8] = b"VIIPER-Session-v1";

/// Fixed salt for PBKDF2 password stretching (fixed, so derivation is deterministic).
const PBKDF2_SALT: &[u8] = b"VIIPER-Key-v1";
/// Number of PBKDF2-HMAC-SHA256 rounds applied to the password.
const PBKDF2_ITERATIONS: u32 = 100_000;
30
31fn derive_key(password: &str) -> Result<[u8; 32], ViiperError> {
33 if password.is_empty() {
34 return Err(ViiperError::UnexpectedResponse("Password cannot be empty".into()));
35 }
36 let mut key = [0u8; 32];
37 pbkdf2_hmac::<Sha256>(password.as_bytes(), PBKDF2_SALT, PBKDF2_ITERATIONS, &mut key);
38 Ok(key)
39}
40
41fn derive_session_key(key: &[u8], server_nonce: &[u8], client_nonce: &[u8]) -> [u8; 32] {
43 let mut hasher = Sha256::new();
44 hasher.update(key);
45 hasher.update(server_nonce);
46 hasher.update(client_nonce);
47 hasher.update(SESSION_CONTEXT);
48 hasher.finalize().into()
49}
50
51pub fn perform_handshake(mut stream: TcpStream, password: &str) -> Result<EncryptedStream, ViiperError> {
53 let key = derive_key(password)?;
54 let mut client_nonce = [0u8; NONCE_SIZE];
55 rand::thread_rng().fill_bytes(&mut client_nonce);
56
57 let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
58 .map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
59 mac.update(AUTH_CONTEXT);
60 mac.update(&client_nonce);
61 let auth_tag = mac.finalize().into_bytes();
62
63 let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + 32);
64 handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
65 handshake_msg.extend_from_slice(&client_nonce);
66 handshake_msg.extend_from_slice(&auth_tag);
67
68 stream.write_all(&handshake_msg)?;
69
70 let mut response = vec![0u8; 3 + NONCE_SIZE];
71 stream.read_exact(&mut response)?;
72
73 if &response[0..3] != b"OK\x00" {
74 let mut error_buf = Vec::new();
75 let _ = stream.read_to_end(&mut error_buf);
76 let full_response = [response, error_buf].concat();
77 let error_str = String::from_utf8_lossy(&full_response);
78
79 if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
80 return Err(ViiperError::Protocol(problem));
81 }
82 return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
83 }
84
85 let server_nonce = &response[3..];
86
87 let session_key = derive_session_key(&key, server_nonce, &client_nonce);
88
89 Ok(EncryptedStream::new(stream, session_key)?)
90}
91
92#[cfg(feature = "async")]
94pub async fn perform_handshake_async(mut stream: AsyncTcpStream, password: &str) -> Result<AsyncEncryptedStream, ViiperError> {
95 let key = derive_key(password)?;
96
97 let mut client_nonce = [0u8; NONCE_SIZE];
98 rand::thread_rng().fill_bytes(&mut client_nonce);
99
100 let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
101 .map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
102 mac.update(AUTH_CONTEXT);
103 mac.update(&client_nonce);
104 let auth_tag = mac.finalize().into_bytes();
105
106 let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + 32);
107 handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
108 handshake_msg.extend_from_slice(&client_nonce);
109 handshake_msg.extend_from_slice(&auth_tag);
110
111 stream.write_all(&handshake_msg).await?;
112
113 let mut response = vec![0u8; 3 + NONCE_SIZE];
114 stream.read_exact(&mut response).await?;
115
116 if &response[0..3] != b"OK\x00" {
117 let mut error_buf = Vec::new();
118 let _ = stream.read_to_end(&mut error_buf).await;
119 let full_response = [response, error_buf].concat();
120 let error_str = String::from_utf8_lossy(&full_response);
121
122 if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
123 return Err(ViiperError::Protocol(problem));
124 }
125 return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
126 }
127
128 let server_nonce = &response[3..];
129
130 let session_key = derive_session_key(&key, server_nonce, &client_nonce);
131
132 Ok(AsyncEncryptedStream::new(stream, session_key))
133}
134
135pub struct EncryptedStream {
139 read: std::sync::Arc<std::sync::Mutex<EncryptedReadState>>,
140 write: std::sync::Arc<std::sync::Mutex<EncryptedWriteState>>,
141}
142
143struct EncryptedReadState {
144 stream: TcpStream,
145 cipher: ChaCha20Poly1305,
146 recv_buffer: Vec<u8>,
147}
148
149struct EncryptedWriteState {
150 stream: TcpStream,
151 cipher: ChaCha20Poly1305,
152 send_counter: u64,
153}
154
155impl EncryptedStream {
156 fn new(inner: TcpStream, session_key: [u8; 32]) -> Result<Self, ViiperError> {
157 let read_stream = inner.try_clone()?;
158 let read_cipher = ChaCha20Poly1305::new(&session_key.into());
159 let write_cipher = ChaCha20Poly1305::new(&session_key.into());
160 Ok(Self {
161 read: std::sync::Arc::new(std::sync::Mutex::new(EncryptedReadState {
162 stream: read_stream,
163 cipher: read_cipher,
164 recv_buffer: Vec::new(),
165 })),
166 write: std::sync::Arc::new(std::sync::Mutex::new(EncryptedWriteState {
167 stream: inner,
168 cipher: write_cipher,
169 send_counter: 0,
170 })),
171 })
172 }
173
174 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
175 let read = self.read.lock().unwrap();
176 let write = self.write.lock().unwrap();
177 read.stream.set_nodelay(nodelay)?;
178 write.stream.set_nodelay(nodelay)
179 }
180
181 pub fn try_clone(&self) -> std::io::Result<Self> {
182 Ok(Self {
183 read: std::sync::Arc::clone(&self.read),
184 write: std::sync::Arc::clone(&self.write),
185 })
186 }
187
188 pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
189 let read = self.read.lock().unwrap();
190 let write = self.write.lock().unwrap();
191 let _ = read.stream.shutdown(how);
192 write.stream.shutdown(how)
193 }
194}
195
196impl Read for EncryptedStream {
197 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
198 let mut inner = self.read.lock().unwrap();
199
200 if inner.recv_buffer.is_empty() {
201 let mut first_byte = [0u8; 1];
202 let n = inner.stream.read(&mut first_byte)?;
203 if n == 0 {
204 return Ok(0);
205 }
206
207 let mut len_buf = [0u8; 4];
208 len_buf[0] = first_byte[0];
209 inner.stream.read_exact(&mut len_buf[1..])?;
210 let packet_len = u32::from_be_bytes(len_buf) as usize;
211
212 if packet_len > 2 * 1024 * 1024 {
213 return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Packet too large"));
214 }
215
216 let mut packet = vec![0u8; packet_len];
217 inner.stream.read_exact(&mut packet)?;
218
219 let nonce = Nonce::from_slice(&packet[0..12]);
220 let ciphertext_and_tag = &packet[12..];
221
222 let plaintext = inner.cipher.decrypt(nonce, ciphertext_and_tag)
223 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))?;
224
225 inner.recv_buffer = plaintext;
226 }
227
228 let to_copy = buf.len().min(inner.recv_buffer.len());
229 buf[..to_copy].copy_from_slice(&inner.recv_buffer[..to_copy]);
230 inner.recv_buffer.drain(..to_copy);
231 Ok(to_copy)
232 }
233}
234
235impl Write for EncryptedStream {
236 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
237 let mut inner = self.write.lock().unwrap();
238
239 let mut nonce_bytes = [0u8; 12];
240 nonce_bytes[4..].copy_from_slice(&inner.send_counter.to_be_bytes());
241 inner.send_counter += 1;
242 let nonce = Nonce::from_slice(&nonce_bytes);
243
244 let ciphertext = inner.cipher.encrypt(nonce, buf)
245 .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;
246
247 let packet = [&nonce_bytes[..], ciphertext.as_slice()].concat();
248 let len_buf = (packet.len() as u32).to_be_bytes();
249
250 inner.stream.write_all(&len_buf)?;
251 inner.stream.write_all(&packet)?;
252
253 Ok(buf.len())
254 }
255
256 fn flush(&mut self) -> std::io::Result<()> {
257 let mut inner = self.write.lock().unwrap();
258 inner.stream.flush()
259 }
260}
261
262#[cfg(feature = "async")]
264pub struct AsyncEncryptedStream {
265 read: AsyncEncryptedRead,
266 write: AsyncEncryptedWrite,
267}
268
269#[cfg(feature = "async")]
270pub struct AsyncEncryptedRead {
271 inner: tokio::net::tcp::OwnedReadHalf,
272 cipher: ChaCha20Poly1305,
273 recv_buffer: Vec<u8>,
274 read_state: ReadState,
275}
276
#[cfg(feature = "async")]
/// Upper bound on one frame body (nonce + ciphertext + tag), in bytes.
const MAX_FRAME_LEN: usize = 2 * 1024 * 1024;
#[cfg(feature = "async")]
/// ChaCha20-Poly1305 nonce length; the nonce leads every frame body.
const FRAME_NONCE_LEN: usize = 12;
#[cfg(feature = "async")]
/// Poly1305 tag length appended to every ciphertext.
const FRAME_TAG_LEN: usize = 16;
#[cfg(feature = "async")]
/// Largest plaintext that fits in a single frame.
const MAX_FRAME_PLAINTEXT_LEN: usize = MAX_FRAME_LEN - FRAME_NONCE_LEN - FRAME_TAG_LEN;

#[cfg(feature = "async")]
/// Sending half of an async encrypted connection: seals writes into length-prefixed
/// ChaCha20-Poly1305 frames.
pub struct AsyncEncryptedWrite {
    inner: tokio::net::tcp::OwnedWriteHalf,
    cipher: ChaCha20Poly1305,
    /// Counter encoded into the low 8 bytes of each outgoing nonce; incremented per frame.
    send_counter: u64,
    /// Unsent tail of a frame the socket only partially accepted. It must reach the wire
    /// before any later frame, or the peer's framing desynchronises.
    pending: Vec<u8>,
}

#[cfg(feature = "async")]
/// Progress of [`AsyncEncryptedRead`] through the current incoming frame.
enum ReadState {
    /// Collecting the 4-byte big-endian frame length; `pos` bytes received so far.
    ReadingLength { buf: [u8; 4], pos: usize },
    /// Collecting `expected_len` frame-body bytes (nonce + ciphertext + tag) into `buf`.
    ReadingPacket { expected_len: usize, buf: Vec<u8>, pos: usize },
    /// Placeholder held in `read_state` while the real state is being processed.
    Ready,
}

#[cfg(feature = "async")]
impl AsyncEncryptedStream {
    /// Splits `inner` into owned halves and gives each direction its own cipher instance,
    /// both keyed with `session_key`.
    fn new(inner: AsyncTcpStream, session_key: [u8; 32]) -> Self {
        let (read_half, write_half) = inner.into_split();
        let read_cipher = ChaCha20Poly1305::new(&session_key.into());
        let write_cipher = ChaCha20Poly1305::new(&session_key.into());
        Self {
            read: AsyncEncryptedRead {
                inner: read_half,
                cipher: read_cipher,
                recv_buffer: Vec::new(),
                read_state: ReadState::ReadingLength { buf: [0; 4], pos: 0 },
            },
            write: AsyncEncryptedWrite {
                inner: write_half,
                cipher: write_cipher,
                send_counter: 0,
                pending: Vec::new(),
            },
        }
    }

    /// Separates the stream into its read and write halves.
    pub fn into_split(self) -> (AsyncEncryptedRead, AsyncEncryptedWrite) {
        (self.read, self.write)
    }
}

#[cfg(feature = "async")]
impl AsyncEncryptedRead {
    /// Moves as much buffered plaintext into `buf` as fits.
    fn copy_buffered(&mut self, buf: &mut ReadBuf<'_>) {
        let to_copy = buf.remaining().min(self.recv_buffer.len());
        buf.put_slice(&self.recv_buffer[..to_copy]);
        self.recv_buffer.drain(..to_copy);
    }
}

#[cfg(feature = "async")]
impl AsyncRead for AsyncEncryptedRead {
    /// Reads decrypted application data, reassembling frames across `Poll::Pending`.
    ///
    /// A frame is a 4-byte big-endian length `L` followed by `L` bytes: a 12-byte nonce and the
    /// ChaCha20-Poly1305 ciphertext and tag. `Ready(Ok(()))` with nothing added to `buf` means
    /// the peer closed cleanly between frames, or that `buf` had no room. Zero-length frames are
    /// skipped rather than reported as EOF.
    ///
    /// # Errors
    /// `InvalidData` for oversized or undersized frames and authentication failures,
    /// `UnexpectedEof` if the peer closes mid-frame, plus any socket error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();

        // Serve leftovers from the previous frame before touching the socket.
        if !this.recv_buffer.is_empty() {
            this.copy_buffered(buf);
            return Poll::Ready(Ok(()));
        }
        // Nothing can be delivered into a full buffer; don't consume a frame for it.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        loop {
            match std::mem::replace(&mut this.read_state, ReadState::Ready) {
                ReadState::ReadingLength { buf: mut len_buf, pos } => {
                    let mut chunk = ReadBuf::new(&mut len_buf[pos..]);
                    let n = match Pin::new(&mut this.inner).poll_read(cx, &mut chunk) {
                        Poll::Ready(Ok(())) => chunk.filled().len(),
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => {
                            this.read_state = ReadState::ReadingLength { buf: len_buf, pos };
                            return Poll::Pending;
                        }
                    };

                    if n == 0 {
                        this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };
                        if pos == 0 {
                            // Clean EOF on a frame boundary.
                            return Poll::Ready(Ok(()));
                        }
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Connection closed while reading length",
                        )));
                    }

                    let pos = pos + n;
                    if pos < len_buf.len() {
                        this.read_state = ReadState::ReadingLength { buf: len_buf, pos };
                        continue;
                    }

                    let packet_len = u32::from_be_bytes(len_buf) as usize;
                    if packet_len > MAX_FRAME_LEN {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too large",
                        )));
                    }
                    // Network-controlled length: a frame without room for nonce + tag would
                    // otherwise panic when the nonce is sliced off.
                    if packet_len < FRAME_NONCE_LEN + FRAME_TAG_LEN {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too small",
                        )));
                    }
                    this.read_state = ReadState::ReadingPacket {
                        expected_len: packet_len,
                        buf: vec![0u8; packet_len],
                        pos: 0,
                    };
                }
                ReadState::ReadingPacket { expected_len, buf: mut packet, pos } => {
                    let mut chunk = ReadBuf::new(&mut packet[pos..]);
                    let n = match Pin::new(&mut this.inner).poll_read(cx, &mut chunk) {
                        Poll::Ready(Ok(())) => chunk.filled().len(),
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => {
                            this.read_state = ReadState::ReadingPacket {
                                expected_len,
                                buf: packet,
                                pos,
                            };
                            return Poll::Pending;
                        }
                    };

                    if n == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Connection closed while reading packet",
                        )));
                    }

                    let pos = pos + n;
                    if pos < expected_len {
                        this.read_state = ReadState::ReadingPacket {
                            expected_len,
                            buf: packet,
                            pos,
                        };
                        continue;
                    }

                    this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };
                    let (nonce, ciphertext_and_tag) = packet.split_at(FRAME_NONCE_LEN);
                    let plaintext = this
                        .cipher
                        .decrypt(Nonce::from_slice(nonce), ciphertext_and_tag)
                        .map_err(|_| {
                            std::io::Error::new(
                                std::io::ErrorKind::InvalidData,
                                "Decryption failed",
                            )
                        })?;

                    if plaintext.is_empty() {
                        // An empty frame carries no data; reporting it would look like EOF.
                        continue;
                    }
                    this.recv_buffer = plaintext;
                    this.copy_buffered(buf);
                    return Poll::Ready(Ok(()));
                }
                ReadState::Ready => {
                    this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };
                }
            }
        }
    }
}

#[cfg(feature = "async")]
impl AsyncEncryptedWrite {
    /// Writes the buffered tail of a partially sent frame; `Ready(Ok(()))` once nothing is left.
    fn poll_write_pending(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        while !self.pending.is_empty() {
            match Pin::new(&mut self.inner).poll_write(cx, &self.pending) {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "Failed to write complete packet",
                    )))
                }
                Poll::Ready(Ok(n)) => {
                    self.pending.drain(..n);
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
        Poll::Ready(Ok(()))
    }

    /// Seals `plaintext` into a complete wire frame:
    /// `len (4 bytes, big-endian) || nonce (12) || ciphertext+tag`.
    ///
    /// The nonce is four zero bytes followed by the big-endian send counter, which is advanced
    /// before encrypting so no counter value is ever reused.
    ///
    /// NOTE(review): both directions share the session key. If the server builds its nonces the
    /// same way, the directions reuse (key, nonce) pairs. Confirm the server-side nonce scheme.
    fn seal_frame(&mut self, plaintext: &[u8]) -> std::io::Result<Vec<u8>> {
        let mut nonce_bytes = [0u8; FRAME_NONCE_LEN];
        nonce_bytes[4..].copy_from_slice(&self.send_counter.to_be_bytes());
        self.send_counter += 1;

        let ciphertext = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;

        let packet_len = u32::try_from(FRAME_NONCE_LEN + ciphertext.len()).map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Packet too large")
        })?;

        let mut frame = Vec::with_capacity(4 + FRAME_NONCE_LEN + ciphertext.len());
        frame.extend_from_slice(&packet_len.to_be_bytes());
        frame.extend_from_slice(&nonce_bytes);
        frame.extend_from_slice(&ciphertext);
        Ok(frame)
    }
}

#[cfg(feature = "async")]
impl AsyncWrite for AsyncEncryptedWrite {
    /// Encrypts a prefix of `buf` (at most one frame's worth) and starts sending it.
    ///
    /// If the socket takes only part of the frame, the rest is buffered and returns `Ok`.
    /// The buffered tail is finished by the next `poll_write`, `poll_flush` or `poll_shutdown`,
    /// so callers must flush (as with any buffered `AsyncWrite`) to guarantee delivery.
    /// If the socket takes none of the frame (`Pending`), the frame is dropped and the caller is
    /// woken to retry. The skipped counter value is simply never reused.
    /// An empty `buf` sends nothing and returns `Ok(0)`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();

        // Finish any earlier frame first so frames never interleave on the wire.
        match this.poll_write_pending(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        }

        // An empty frame would be read as EOF by the reader side in this file.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Stay within the 2 MiB frame limit the reader enforces (assumed to match the
        // peer's limit — TODO confirm against server).
        let chunk = &buf[..buf.len().min(MAX_FRAME_PLAINTEXT_LEN)];
        let frame = this.seal_frame(chunk)?;

        match Pin::new(&mut this.inner).poll_write(cx, &frame) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(0)) => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "Failed to write complete packet",
            ))),
            Poll::Ready(Ok(written)) => {
                if written < frame.len() {
                    // Part of the frame is already on the wire: keep the rest, never drop it.
                    this.pending.extend_from_slice(&frame[written..]);
                    if let Poll::Ready(Err(e)) = this.poll_write_pending(cx) {
                        return Poll::Ready(Err(e));
                    }
                }
                Poll::Ready(Ok(chunk.len()))
            }
        }
    }

    /// Sends any buffered frame tail, then flushes the socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_write_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_flush(cx),
            other => other,
        }
    }

    /// Sends any buffered frame tail, then shuts down the write side of the socket.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_write_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_shutdown(cx),
            other => other,
        }
    }
}
476
477#[cfg(feature = "async")]
478impl AsyncRead for AsyncEncryptedStream {
479 fn poll_read(
480 mut self: Pin<&mut Self>,
481 cx: &mut Context<'_>,
482 buf: &mut ReadBuf<'_>,
483 ) -> Poll<std::io::Result<()>> {
484 Pin::new(&mut self.read).poll_read(cx, buf)
485 }
486}
487
488#[cfg(feature = "async")]
489impl AsyncWrite for AsyncEncryptedStream {
490 fn poll_write(
491 mut self: Pin<&mut Self>,
492 cx: &mut Context<'_>,
493 buf: &[u8],
494 ) -> Poll<Result<usize, std::io::Error>> {
495 Pin::new(&mut self.write).poll_write(cx, buf)
496 }
497
498 fn poll_flush(
499 mut self: Pin<&mut Self>,
500 cx: &mut Context<'_>,
501 ) -> Poll<Result<(), std::io::Error>> {
502 Pin::new(&mut self.write).poll_flush(cx)
503 }
504
505 fn poll_shutdown(
506 mut self: Pin<&mut Self>,
507 cx: &mut Context<'_>,
508 ) -> Poll<Result<(), std::io::Error>> {
509 Pin::new(&mut self.write).poll_shutdown(cx)
510 }
511}