1#![allow(clippy::cast_possible_truncation)]
29
30use ring::{hmac, rand::{SecureRandom, SystemRandom}};
31use std::io::{Read, Write};
32use subtle::ConstantTimeEq;
33
34use crate::error::{Result, ShieldError};
35use crate::exchange::PAKEExchange;
36use crate::ratchet::RatchetSession;
37
/// Version byte written as the first header byte of every handshake message;
/// `recv_handshake` rejects any other value.
const PROTOCOL_VERSION: u8 = 1;

/// Upper bound, in bytes, on a plaintext passed to `ShieldChannel::send`
/// (16 MiB). The frame-size checks on the wire are also based on this value.
const MAX_MESSAGE_SIZE: usize = 1 << 24;
43
/// Handshake message tags, carried as the second header byte of each
/// handshake message.
#[derive(Clone, Copy)]
#[repr(u8)]
enum HandshakeType {
    /// Initiator to responder: the initiator's 16-byte random salt.
    ClientHello = 1,
    /// Responder to initiator: the 16-byte mixed salt followed by the
    /// responder's 32-byte contribution.
    ServerHello = 2,
    /// Initiator to responder: the initiator's 32-byte contribution.
    Finished = 3,
}
52
/// Settings both peers use to run a `ShieldChannel` handshake.
#[derive(Clone)]
pub struct ChannelConfig {
    /// Shared secret; passed to every `PAKEExchange::derive` call in the handshake.
    password: String,
    /// Service name copied onto the resulting channel. It is not mixed into
    /// any key derivation in this file.
    service: String,
    /// Iteration count forwarded to `PAKEExchange::derive`.
    iterations: u32,
    /// Handshake timeout in milliseconds.
    /// NOTE(review): the blocking `ShieldChannel` handshake in this file never
    /// reads this value — presumably used by an async transport; confirm.
    handshake_timeout_ms: u64,
}
65
66impl ChannelConfig {
67 #[must_use]
73 pub fn new(password: &str, service: &str) -> Self {
74 Self {
75 password: password.to_string(),
76 service: service.to_string(),
77 iterations: PAKEExchange::ITERATIONS,
78 handshake_timeout_ms: 30_000,
79 }
80 }
81
82 #[must_use]
84 pub fn with_iterations(mut self, iterations: u32) -> Self {
85 self.iterations = iterations;
86 self
87 }
88
89 #[must_use]
91 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
92 self.handshake_timeout_ms = timeout_ms;
93 self
94 }
95
96 #[cfg(feature = "async")]
98 #[must_use]
99 pub(crate) fn password(&self) -> &str {
100 &self.password
101 }
102
103 #[must_use]
105 pub fn service(&self) -> &str {
106 &self.service
107 }
108
109 #[cfg(feature = "async")]
111 #[must_use]
112 pub(crate) fn iterations(&self) -> u32 {
113 self.iterations
114 }
115}
116
/// Per-peer scratch state built up while the handshake runs.
struct HandshakeState {
    /// Salt for every derivation. It starts random. The responder then XORs in
    /// the initiator's salt, and the initiator adopts the responder's result.
    salt: [u8; 16],
    /// This side's contribution; all zeros until `derive_contribution` runs.
    local_contribution: [u8; 32],
    /// The peer's contribution, once it has been received.
    remote_contribution: Option<[u8; 32]>,
    /// `true` on the connecting (client) side.
    is_initiator: bool,
}
124
125impl HandshakeState {
126 fn new(is_initiator: bool) -> Result<Self> {
127 let rng = SystemRandom::new();
128
129 let mut salt = [0u8; 16];
130 rng.fill(&mut salt).map_err(|_| ShieldError::RandomFailed)?;
131
132 Ok(Self {
133 salt,
134 local_contribution: [0u8; 32],
135 remote_contribution: None,
136 is_initiator,
137 })
138 }
139
140 fn derive_contribution(&mut self, config: &ChannelConfig) {
141 let role = if self.is_initiator { "client" } else { "server" };
142 self.local_contribution = PAKEExchange::derive(
143 &config.password,
144 &self.salt,
145 role,
146 Some(config.iterations),
147 );
148 }
149
150 fn compute_session_key(&self, config: &ChannelConfig) -> Result<[u8; 32]> {
151 let remote = self.remote_contribution
152 .ok_or_else(|| ShieldError::ChannelError("handshake incomplete".into()))?;
153
154 let base_key = PAKEExchange::combine(&[self.local_contribution, remote]);
158
159 let password_key = PAKEExchange::derive(
161 &config.password,
162 &self.salt,
163 "session",
164 Some(config.iterations),
165 );
166
167 let mut combined = Vec::with_capacity(64);
169 combined.extend_from_slice(&base_key);
170 combined.extend_from_slice(&password_key);
171
172 let hash = ring::digest::digest(&ring::digest::SHA256, &combined);
173 let mut result = [0u8; 32];
174 result.copy_from_slice(hash.as_ref());
175 Ok(result)
176 }
177}
178
179pub struct ShieldChannel<S> {
186 stream: S,
187 session: RatchetSession,
188 service: String,
189}
190
191impl<S: Read + Write> ShieldChannel<S> {
192 pub fn connect(mut stream: S, config: &ChannelConfig) -> Result<Self> {
200 let mut state = HandshakeState::new(true)?;
201
202 Self::send_handshake(&mut stream, HandshakeType::ClientHello, &state.salt)?;
204
205 let server_hello = Self::recv_handshake(&mut stream, HandshakeType::ServerHello)?;
207 if server_hello.len() != 48 {
208 return Err(ShieldError::ChannelError("invalid ServerHello".into()));
209 }
210
211 state.salt.copy_from_slice(&server_hello[..16]);
213 state.derive_contribution(config);
214
215 let mut remote = [0u8; 32];
216 remote.copy_from_slice(&server_hello[16..48]);
217 state.remote_contribution = Some(remote);
218
219 Self::send_handshake(&mut stream, HandshakeType::Finished, &state.local_contribution)?;
221
222 let session_key = state.compute_session_key(config)?;
224 let session = RatchetSession::new(&session_key, true);
225
226 Self::send_confirmation(&mut stream, &session_key, true)?;
228 Self::verify_confirmation(&mut stream, &session_key, false)?;
229
230 Ok(Self {
231 stream,
232 session,
233 service: config.service.clone(),
234 })
235 }
236
237 pub fn accept(mut stream: S, config: &ChannelConfig) -> Result<Self> {
245 let mut state = HandshakeState::new(false)?;
246
247 let client_hello = Self::recv_handshake(&mut stream, HandshakeType::ClientHello)?;
249 if client_hello.len() != 16 {
250 return Err(ShieldError::ChannelError("invalid ClientHello".into()));
251 }
252
253 for (i, &b) in client_hello.iter().enumerate() {
255 state.salt[i] ^= b;
256 }
257
258 state.derive_contribution(config);
259
260 let mut server_hello = Vec::with_capacity(48);
262 server_hello.extend_from_slice(&state.salt);
263 server_hello.extend_from_slice(&state.local_contribution);
264 Self::send_handshake(&mut stream, HandshakeType::ServerHello, &server_hello)?;
265
266 let client_finished = Self::recv_handshake(&mut stream, HandshakeType::Finished)?;
268 if client_finished.len() != 32 {
269 return Err(ShieldError::ChannelError("invalid Finished".into()));
270 }
271
272 let mut remote = [0u8; 32];
273 remote.copy_from_slice(&client_finished);
274 state.remote_contribution = Some(remote);
275
276 let session_key = state.compute_session_key(config)?;
278 let session = RatchetSession::new(&session_key, false);
279
280 Self::verify_confirmation(&mut stream, &session_key, true)?;
282 Self::send_confirmation(&mut stream, &session_key, false)?;
283
284 Ok(Self {
285 stream,
286 session,
287 service: config.service.clone(),
288 })
289 }
290
291 pub fn send(&mut self, data: &[u8]) -> Result<()> {
295 if data.len() > MAX_MESSAGE_SIZE {
296 return Err(ShieldError::ChannelError(format!(
297 "message too large: {} > {}",
298 data.len(),
299 MAX_MESSAGE_SIZE
300 )));
301 }
302
303 let encrypted = self.session.encrypt(data)?;
304 Self::write_frame(&mut self.stream, &encrypted)
305 }
306
307 pub fn recv(&mut self) -> Result<Vec<u8>> {
311 let encrypted = Self::read_frame(&mut self.stream)?;
312 self.session.decrypt(&encrypted)
313 }
314
315 #[must_use]
317 pub fn service(&self) -> &str {
318 &self.service
319 }
320
321 #[must_use]
323 pub fn messages_sent(&self) -> u64 {
324 self.session.send_counter()
325 }
326
327 #[must_use]
329 pub fn messages_received(&self) -> u64 {
330 self.session.recv_counter()
331 }
332
333 pub fn into_inner(self) -> S {
335 self.stream
336 }
337
338 fn send_handshake(stream: &mut S, msg_type: HandshakeType, data: &[u8]) -> Result<()> {
341 let mut frame = Vec::with_capacity(1 + 1 + 2 + data.len());
342 frame.push(PROTOCOL_VERSION);
343 frame.push(msg_type as u8);
344 frame.extend_from_slice(&(data.len() as u16).to_be_bytes());
345 frame.extend_from_slice(data);
346
347 stream.write_all(&frame).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
348 stream.flush().map_err(|e| ShieldError::ChannelError(e.to_string()))?;
349 Ok(())
350 }
351
352 fn recv_handshake(stream: &mut S, expected_type: HandshakeType) -> Result<Vec<u8>> {
353 let mut header = [0u8; 4];
354 stream.read_exact(&mut header).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
355
356 if header[0] != PROTOCOL_VERSION {
357 return Err(ShieldError::ChannelError(format!(
358 "unsupported protocol version: {}",
359 header[0]
360 )));
361 }
362
363 if header[1] != expected_type as u8 {
364 return Err(ShieldError::ChannelError(format!(
365 "unexpected message type: expected {}, got {}",
366 expected_type as u8,
367 header[1]
368 )));
369 }
370
371 let len = u16::from_be_bytes([header[2], header[3]]) as usize;
372 if len > 1024 {
373 return Err(ShieldError::ChannelError("handshake message too large".into()));
374 }
375
376 let mut data = vec![0u8; len];
377 stream.read_exact(&mut data).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
378
379 Ok(data)
380 }
381
382 fn send_confirmation(stream: &mut S, session_key: &[u8; 32], is_client: bool) -> Result<()> {
383 let label = if is_client { b"client-confirm" } else { b"server-confirm" };
384 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
385 let confirm = hmac::sign(&hmac_key, label);
386
387 Self::write_frame(stream, &confirm.as_ref()[..16])
388 }
389
390 fn verify_confirmation(stream: &mut S, session_key: &[u8; 32], expect_client: bool) -> Result<()> {
391 let received = Self::read_frame(stream)?;
392 if received.len() != 16 {
393 return Err(ShieldError::ChannelError("invalid confirmation".into()));
394 }
395
396 let label = if expect_client { b"client-confirm" } else { b"server-confirm" };
397 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
398 let expected = hmac::sign(&hmac_key, label);
399
400 if received.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
401 return Err(ShieldError::AuthenticationFailed);
402 }
403
404 Ok(())
405 }
406
407 fn write_frame(stream: &mut S, data: &[u8]) -> Result<()> {
410 let len = data.len() as u32;
411 stream.write_all(&len.to_be_bytes()).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
412 stream.write_all(data).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
413 stream.flush().map_err(|e| ShieldError::ChannelError(e.to_string()))?;
414 Ok(())
415 }
416
417 fn read_frame(stream: &mut S) -> Result<Vec<u8>> {
418 let mut len_buf = [0u8; 4];
419 stream.read_exact(&mut len_buf).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
420
421 let len = u32::from_be_bytes(len_buf) as usize;
422 if len > MAX_MESSAGE_SIZE {
423 return Err(ShieldError::ChannelError(format!(
424 "frame too large: {len} > {MAX_MESSAGE_SIZE}"
425 )));
426 }
427
428 let mut data = vec![0u8; len];
429 stream.read_exact(&mut data).map_err(|e| ShieldError::ChannelError(e.to_string()))?;
430
431 Ok(data)
432 }
433}
434
435pub struct ShieldListener<L> {
437 listener: L,
438 config: ChannelConfig,
439}
440
441impl<L> ShieldListener<L> {
442 #[must_use]
444 pub fn new(listener: L, config: ChannelConfig) -> Self {
445 Self { listener, config }
446 }
447
448 pub fn into_inner(self) -> L {
450 self.listener
451 }
452
453 #[must_use]
455 pub fn config(&self) -> &ChannelConfig {
456 &self.config
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// In-memory duplex byte pipe for driving both ends of a channel from two
    /// threads.
    ///
    /// Bytes written on one end are read on the other. Once an end is dropped,
    /// the peer's reads fail with `UnexpectedEof` and its writes with
    /// `BrokenPipe`. This is what unblocks one side when the other aborts the
    /// handshake.
    struct MockStream {
        tx: mpsc::Sender<u8>,
        rx: mpsc::Receiver<u8>,
    }

    impl MockStream {
        /// Returns two connected ends.
        fn pair() -> (Self, Self) {
            let (tx1, rx1) = mpsc::channel();
            let (tx2, rx2) = mpsc::channel();

            (
                Self { tx: tx1, rx: rx2 },
                Self { tx: tx2, rx: rx1 },
            )
        }
    }

    impl Read for MockStream {
        // Blocks until `buf` is completely filled or the peer hangs up.
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            for byte in buf.iter_mut() {
                *byte = self.rx.recv().map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "channel closed")
                })?;
            }
            Ok(buf.len())
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            for &byte in buf {
                self.tx.send(byte).map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed")
                })?;
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Wraps a [`MockStream`] and records every byte written through it, so
    /// tests can inspect what actually crossed the wire.
    struct TapStream {
        inner: MockStream,
        written: Arc<Mutex<Vec<u8>>>,
    }

    impl Read for TapStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.inner.read(buf)
        }
    }

    impl Write for TapStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let n = self.inner.write(buf)?;
            self.written
                .lock()
                .expect("tap mutex poisoned")
                .extend_from_slice(&buf[..n]);
            Ok(n)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.inner.flush()
        }
    }

    #[test]
    fn test_channel_handshake() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("test-password", "test.service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            ShieldChannel::accept(server_stream, &server_config)
        });

        let client = ShieldChannel::connect(client_stream, &config).unwrap();
        let server = server_handle.join().unwrap().unwrap();

        assert_eq!(client.service(), "test.service");
        assert_eq!(server.service(), "test.service");
    }

    #[test]
    fn test_channel_message_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("secret", "messaging.app");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg = server.recv().unwrap();
            server.send(b"Hello client!").unwrap();
            msg
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"Hello server!").unwrap();
        let response = client.recv().unwrap();

        let server_received = server_handle.join().unwrap();

        assert_eq!(server_received, b"Hello server!");
        assert_eq!(response, b"Hello client!");
    }

    // NOTE(review): despite its name, this only checks that several messages
    // in a row decrypt and are counted; it does not exercise key erasure.
    #[test]
    fn test_channel_forward_secrecy() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            for _ in 0..3 {
                let _ = server.recv().unwrap();
            }
            server.messages_received()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        client.send(b"message 1").unwrap();
        client.send(b"message 2").unwrap();
        client.send(b"message 3").unwrap();

        assert_eq!(client.messages_sent(), 3);

        let server_count = server_handle.join().unwrap();
        assert_eq!(server_count, 3);
    }

    #[test]
    fn test_channel_wrong_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password1", "service");
        let server_config = ChannelConfig::new("password2", "service");

        let server_handle = thread::spawn(move || {
            ShieldChannel::accept(server_stream, &server_config)
        });

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // The server verifies the client's key confirmation first. The session
        // keys differ, so that check must fail with AuthenticationFailed.
        assert!(
            matches!(server_result, Err(ShieldError::AuthenticationFailed)),
            "server must reject the client's confirmation"
        );
        // The failed `accept` drops the server's stream. The client's wait for
        // the server confirmation then ends in an error instead of a channel.
        assert!(client_result.is_err(), "client must not obtain a channel");
    }

    #[test]
    fn test_config_builder() {
        let config = ChannelConfig::new("password", "service")
            .with_iterations(100_000)
            .with_timeout(5_000);

        assert_eq!(config.iterations, 100_000);
        assert_eq!(config.handshake_timeout_ms, 5_000);
    }

    #[test]
    fn test_empty_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            server.recv().unwrap()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"").unwrap();

        let received = server_handle.join().unwrap();
        assert_eq!(received, b"");
    }

    #[test]
    fn test_large_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let large_data: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();

        let server_config = config.clone();
        let expected_data = large_data.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let received = server.recv().unwrap();
            assert_eq!(received.len(), expected_data.len());
            assert_eq!(received, expected_data);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(&large_data).unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_bidirectional_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();

            server.send(b"Server says hello").unwrap();

            let msg = server.recv().unwrap();
            assert_eq!(msg, b"Client responds");

            server.send(b"Server acknowledges").unwrap();
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        let msg1 = client.recv().unwrap();
        assert_eq!(msg1, b"Server says hello");

        client.send(b"Client responds").unwrap();

        let msg2 = client.recv().unwrap();
        assert_eq!(msg2, b"Server acknowledges");

        server_handle.join().unwrap();
    }

    // The handshake never mixes `service` into key derivation, so both sides
    // are expected to succeed. Each channel then reports its own configured
    // service name.
    #[test]
    fn test_different_services_same_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password", "service1");
        let server_config = ChannelConfig::new("password", "service2");

        let server_handle = thread::spawn(move || {
            ShieldChannel::accept(server_stream, &server_config)
        });

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        if let (Ok(client), Ok(server)) = (client_result, server_result) {
            assert_eq!(client.service(), "service1");
            assert_eq!(server.service(), "service2");
        }
    }

    #[test]
    fn test_unique_ciphertext_per_message() {
        const MSG: &[u8] = b"same message";

        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg1 = server.recv().unwrap();
            let msg2 = server.recv().unwrap();
            assert_eq!(msg1, msg2);
            assert_eq!(msg1, MSG);
        });

        let written = Arc::new(Mutex::new(Vec::new()));
        let tap = TapStream {
            inner: client_stream,
            written: Arc::clone(&written),
        };
        let mut client = ShieldChannel::connect(tap, &config).unwrap();

        // Record the wire offsets around each send to isolate the two frames.
        let start = written.lock().unwrap().len();
        client.send(MSG).unwrap();
        let mid = written.lock().unwrap().len();
        client.send(MSG).unwrap();
        let end = written.lock().unwrap().len();

        server_handle.join().unwrap();

        let wire = written.lock().unwrap();
        let first = &wire[start..mid];
        let second = &wire[mid..end];
        assert!(!first.is_empty() && !second.is_empty());
        assert_ne!(first, second, "identical plaintexts must yield different frames");
        for frame in [first, second] {
            assert!(
                !frame.windows(MSG.len()).any(|w| w == MSG),
                "plaintext must not appear on the wire"
            );
        }
    }

    #[test]
    fn test_listener_config() {
        let config = ChannelConfig::new("password", "service");
        let listener = ShieldListener::new((), config.clone());

        assert_eq!(listener.config().service(), "service");

        let _ = listener.into_inner();
    }

    #[test]
    fn test_channel_counters() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            assert_eq!(server.messages_sent(), 0);
            assert_eq!(server.messages_received(), 0);

            let _ = server.recv().unwrap();
            assert_eq!(server.messages_received(), 1);

            server.send(b"response").unwrap();
            assert_eq!(server.messages_sent(), 1);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        assert_eq!(client.messages_sent(), 0);

        client.send(b"hello").unwrap();
        assert_eq!(client.messages_sent(), 1);

        let _ = client.recv().unwrap();
        assert_eq!(client.messages_received(), 1);

        server_handle.join().unwrap();
    }
}