1#![allow(clippy::cast_possible_truncation)]
29
30use ring::{
31 hmac,
32 rand::{SecureRandom, SystemRandom},
33};
34use std::io::{Read, Write};
35use subtle::ConstantTimeEq;
36
37use crate::error::{Result, ShieldError};
38use crate::exchange::PAKEExchange;
39use crate::ratchet::RatchetSession;
40
/// Version byte that leads every handshake frame; a peer announcing any other
/// value is rejected while the handshake is being read.
const PROTOCOL_VERSION: u8 = 1;

/// Largest plaintext (16 MiB) accepted by a single `ShieldChannel::send` call,
/// and the basis for the size limit applied to incoming frames.
const MAX_MESSAGE_SIZE: usize = 1 << 24;
46
/// Identifies which handshake message a frame carries.
///
/// The discriminant is written verbatim as the second header byte of each
/// handshake frame, so these numeric values are part of the wire protocol.
#[derive(Clone, Copy)]
#[repr(u8)]
enum HandshakeType {
    /// Initiator → responder: the initiator's 16-byte random salt.
    ClientHello = 0x01,
    /// Responder → initiator: the final salt followed by the responder's contribution.
    ServerHello = 0x02,
    /// Initiator → responder: the initiator's 32-byte contribution.
    Finished = 0x03,
}
55
/// Settings both peers feed into a `ShieldChannel` handshake.
///
/// `Debug` is not derived; if one is ever added it should redact `password`.
#[derive(Clone)]
pub struct ChannelConfig {
    /// Shared secret; mixed into every key derived during the handshake.
    password: String,
    /// Label reported back by `ShieldChannel::service` after the handshake.
    service: String,
    /// Iteration count handed to `PAKEExchange::derive`.
    iterations: u32,
    /// Requested handshake timeout in milliseconds.
    ///
    /// NOTE(review): nothing in the handshake code of this module reads this
    /// value; confirm whether it is enforced anywhere.
    handshake_timeout_ms: u64,
}
68
69impl ChannelConfig {
70 #[must_use]
76 pub fn new(password: &str, service: &str) -> Self {
77 Self {
78 password: password.to_string(),
79 service: service.to_string(),
80 iterations: PAKEExchange::ITERATIONS,
81 handshake_timeout_ms: 30_000,
82 }
83 }
84
85 #[must_use]
87 pub fn with_iterations(mut self, iterations: u32) -> Self {
88 self.iterations = iterations;
89 self
90 }
91
92 #[must_use]
94 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
95 self.handshake_timeout_ms = timeout_ms;
96 self
97 }
98
99 #[cfg(feature = "async")]
101 #[must_use]
102 pub(crate) fn password(&self) -> &str {
103 &self.password
104 }
105
106 #[must_use]
108 pub fn service(&self) -> &str {
109 &self.service
110 }
111
112 #[cfg(feature = "async")]
114 #[must_use]
115 pub(crate) fn iterations(&self) -> u32 {
116 self.iterations
117 }
118}
119
/// Per-connection scratch state used while a handshake is in progress.
struct HandshakeState {
    /// Random on creation; later replaced or mixed with the peer's salt.
    salt: [u8; 16],
    /// This side's contribution; all zeros until `derive_contribution` runs.
    local_contribution: [u8; 32],
    /// The peer's contribution once it has been received.
    remote_contribution: Option<[u8; 32]>,
    /// `true` for the connecting (client) side, `false` for the accepting side.
    is_initiator: bool,
}
127
128impl HandshakeState {
129 fn new(is_initiator: bool) -> Result<Self> {
130 let rng = SystemRandom::new();
131
132 let mut salt = [0u8; 16];
133 rng.fill(&mut salt).map_err(|_| ShieldError::RandomFailed)?;
134
135 Ok(Self {
136 salt,
137 local_contribution: [0u8; 32],
138 remote_contribution: None,
139 is_initiator,
140 })
141 }
142
143 fn derive_contribution(&mut self, config: &ChannelConfig) {
144 let role = if self.is_initiator {
145 "client"
146 } else {
147 "server"
148 };
149 self.local_contribution =
150 PAKEExchange::derive(&config.password, &self.salt, role, Some(config.iterations));
151 }
152
153 fn compute_session_key(&self, config: &ChannelConfig) -> Result<[u8; 32]> {
154 let remote = self
155 .remote_contribution
156 .ok_or_else(|| ShieldError::ChannelError("handshake incomplete".into()))?;
157
158 let base_key = PAKEExchange::combine(&[self.local_contribution, remote]);
162
163 let password_key = PAKEExchange::derive(
165 &config.password,
166 &self.salt,
167 "session",
168 Some(config.iterations),
169 );
170
171 let mut combined = Vec::with_capacity(64);
173 combined.extend_from_slice(&base_key);
174 combined.extend_from_slice(&password_key);
175
176 let hash = ring::digest::digest(&ring::digest::SHA256, &combined);
177 let mut result = [0u8; 32];
178 result.copy_from_slice(hash.as_ref());
179 Ok(result)
180 }
181}
182
183pub struct ShieldChannel<S> {
190 stream: S,
191 session: RatchetSession,
192 service: String,
193}
194
195impl<S: Read + Write> ShieldChannel<S> {
196 pub fn connect(mut stream: S, config: &ChannelConfig) -> Result<Self> {
204 let mut state = HandshakeState::new(true)?;
205
206 Self::send_handshake(&mut stream, HandshakeType::ClientHello, &state.salt)?;
208
209 let server_hello = Self::recv_handshake(&mut stream, HandshakeType::ServerHello)?;
211 if server_hello.len() != 48 {
212 return Err(ShieldError::ChannelError("invalid ServerHello".into()));
213 }
214
215 state.salt.copy_from_slice(&server_hello[..16]);
217 state.derive_contribution(config);
218
219 let mut remote = [0u8; 32];
220 remote.copy_from_slice(&server_hello[16..48]);
221 state.remote_contribution = Some(remote);
222
223 Self::send_handshake(
225 &mut stream,
226 HandshakeType::Finished,
227 &state.local_contribution,
228 )?;
229
230 let session_key = state.compute_session_key(config)?;
232 let session = RatchetSession::new(&session_key, true);
233
234 Self::send_confirmation(&mut stream, &session_key, true)?;
236 Self::verify_confirmation(&mut stream, &session_key, false)?;
237
238 Ok(Self {
239 stream,
240 session,
241 service: config.service.clone(),
242 })
243 }
244
245 pub fn accept(mut stream: S, config: &ChannelConfig) -> Result<Self> {
253 let mut state = HandshakeState::new(false)?;
254
255 let client_hello = Self::recv_handshake(&mut stream, HandshakeType::ClientHello)?;
257 if client_hello.len() != 16 {
258 return Err(ShieldError::ChannelError("invalid ClientHello".into()));
259 }
260
261 for (i, &b) in client_hello.iter().enumerate() {
263 state.salt[i] ^= b;
264 }
265
266 state.derive_contribution(config);
267
268 let mut server_hello = Vec::with_capacity(48);
270 server_hello.extend_from_slice(&state.salt);
271 server_hello.extend_from_slice(&state.local_contribution);
272 Self::send_handshake(&mut stream, HandshakeType::ServerHello, &server_hello)?;
273
274 let client_finished = Self::recv_handshake(&mut stream, HandshakeType::Finished)?;
276 if client_finished.len() != 32 {
277 return Err(ShieldError::ChannelError("invalid Finished".into()));
278 }
279
280 let mut remote = [0u8; 32];
281 remote.copy_from_slice(&client_finished);
282 state.remote_contribution = Some(remote);
283
284 let session_key = state.compute_session_key(config)?;
286 let session = RatchetSession::new(&session_key, false);
287
288 Self::verify_confirmation(&mut stream, &session_key, true)?;
290 Self::send_confirmation(&mut stream, &session_key, false)?;
291
292 Ok(Self {
293 stream,
294 session,
295 service: config.service.clone(),
296 })
297 }
298
299 pub fn send(&mut self, data: &[u8]) -> Result<()> {
303 if data.len() > MAX_MESSAGE_SIZE {
304 return Err(ShieldError::ChannelError(format!(
305 "message too large: {} > {}",
306 data.len(),
307 MAX_MESSAGE_SIZE
308 )));
309 }
310
311 let encrypted = self.session.encrypt(data)?;
312 Self::write_frame(&mut self.stream, &encrypted)
313 }
314
315 pub fn recv(&mut self) -> Result<Vec<u8>> {
319 let encrypted = Self::read_frame(&mut self.stream)?;
320 self.session.decrypt(&encrypted)
321 }
322
323 #[must_use]
325 pub fn service(&self) -> &str {
326 &self.service
327 }
328
329 #[must_use]
331 pub fn messages_sent(&self) -> u64 {
332 self.session.send_counter()
333 }
334
335 #[must_use]
337 pub fn messages_received(&self) -> u64 {
338 self.session.recv_counter()
339 }
340
341 pub fn into_inner(self) -> S {
343 self.stream
344 }
345
346 fn send_handshake(stream: &mut S, msg_type: HandshakeType, data: &[u8]) -> Result<()> {
349 let mut frame = Vec::with_capacity(1 + 1 + 2 + data.len());
350 frame.push(PROTOCOL_VERSION);
351 frame.push(msg_type as u8);
352 frame.extend_from_slice(&(data.len() as u16).to_be_bytes());
353 frame.extend_from_slice(data);
354
355 stream
356 .write_all(&frame)
357 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
358 stream
359 .flush()
360 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
361 Ok(())
362 }
363
364 fn recv_handshake(stream: &mut S, expected_type: HandshakeType) -> Result<Vec<u8>> {
365 let mut header = [0u8; 4];
366 stream
367 .read_exact(&mut header)
368 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
369
370 if header[0] != PROTOCOL_VERSION {
371 return Err(ShieldError::ChannelError(format!(
372 "unsupported protocol version: {}",
373 header[0]
374 )));
375 }
376
377 if header[1] != expected_type as u8 {
378 return Err(ShieldError::ChannelError(format!(
379 "unexpected message type: expected {}, got {}",
380 expected_type as u8, header[1]
381 )));
382 }
383
384 let len = u16::from_be_bytes([header[2], header[3]]) as usize;
385 if len > 1024 {
386 return Err(ShieldError::ChannelError(
387 "handshake message too large".into(),
388 ));
389 }
390
391 let mut data = vec![0u8; len];
392 stream
393 .read_exact(&mut data)
394 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
395
396 Ok(data)
397 }
398
399 fn send_confirmation(stream: &mut S, session_key: &[u8; 32], is_client: bool) -> Result<()> {
400 let label = if is_client {
401 b"client-confirm"
402 } else {
403 b"server-confirm"
404 };
405 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
406 let confirm = hmac::sign(&hmac_key, label);
407
408 Self::write_frame(stream, &confirm.as_ref()[..16])
409 }
410
411 fn verify_confirmation(
412 stream: &mut S,
413 session_key: &[u8; 32],
414 expect_client: bool,
415 ) -> Result<()> {
416 let received = Self::read_frame(stream)?;
417 if received.len() != 16 {
418 return Err(ShieldError::ChannelError("invalid confirmation".into()));
419 }
420
421 let label = if expect_client {
422 b"client-confirm"
423 } else {
424 b"server-confirm"
425 };
426 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
427 let expected = hmac::sign(&hmac_key, label);
428
429 if received.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
430 return Err(ShieldError::AuthenticationFailed);
431 }
432
433 Ok(())
434 }
435
436 fn write_frame(stream: &mut S, data: &[u8]) -> Result<()> {
439 let len = data.len() as u32;
440 stream
441 .write_all(&len.to_be_bytes())
442 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
443 stream
444 .write_all(data)
445 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
446 stream
447 .flush()
448 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
449 Ok(())
450 }
451
452 fn read_frame(stream: &mut S) -> Result<Vec<u8>> {
453 let mut len_buf = [0u8; 4];
454 stream
455 .read_exact(&mut len_buf)
456 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
457
458 let len = u32::from_be_bytes(len_buf) as usize;
459 if len > MAX_MESSAGE_SIZE {
460 return Err(ShieldError::ChannelError(format!(
461 "frame too large: {len} > {MAX_MESSAGE_SIZE}"
462 )));
463 }
464
465 let mut data = vec![0u8; len];
466 stream
467 .read_exact(&mut data)
468 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
469
470 Ok(data)
471 }
472}
473
474pub struct ShieldListener<L> {
476 listener: L,
477 config: ChannelConfig,
478}
479
480impl<L> ShieldListener<L> {
481 #[must_use]
483 pub fn new(listener: L, config: ChannelConfig) -> Self {
484 Self { listener, config }
485 }
486
487 pub fn into_inner(self) -> L {
489 self.listener
490 }
491
492 #[must_use]
494 pub fn config(&self) -> &ChannelConfig {
495 &self.config
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;

    /// In-memory duplex byte stream. Bytes written on one end arrive, one at a
    /// time, at the other end's reads. Once one end is dropped, the peer's reads
    /// and writes fail.
    struct MockStream {
        tx: mpsc::Sender<u8>,
        rx: mpsc::Receiver<u8>,
    }

    impl MockStream {
        /// Returns two connected ends.
        fn pair() -> (Self, Self) {
            let (tx1, rx1) = mpsc::channel();
            let (tx2, rx2) = mpsc::channel();

            (Self { tx: tx1, rx: rx2 }, Self { tx: tx2, rx: rx1 })
        }
    }

    impl Read for MockStream {
        // Blocks until the whole buffer is filled or the peer hangs up, so every
        // successful `read` fills `buf` completely.
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            for byte in buf.iter_mut() {
                *byte = self.rx.recv().map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "channel closed")
                })?;
            }
            Ok(buf.len())
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            for &byte in buf {
                self.tx.send(byte).map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed")
                })?;
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_channel_handshake() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("test-password", "test.service");

        let server_config = config.clone();
        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client = ShieldChannel::connect(client_stream, &config).unwrap();
        let server = server_handle.join().unwrap().unwrap();

        assert_eq!(client.service(), "test.service");
        assert_eq!(server.service(), "test.service");
    }

    #[test]
    fn test_channel_message_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("secret", "messaging.app");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg = server.recv().unwrap();
            server.send(b"Hello client!").unwrap();
            msg
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"Hello server!").unwrap();
        let response = client.recv().unwrap();

        let server_received = server_handle.join().unwrap();

        assert_eq!(server_received, b"Hello server!");
        assert_eq!(response, b"Hello client!");
    }

    #[test]
    fn test_channel_forward_secrecy() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            for _ in 0..3 {
                let _ = server.recv().unwrap();
            }
            server.messages_received()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        client.send(b"message 1").unwrap();
        client.send(b"message 2").unwrap();
        client.send(b"message 3").unwrap();

        assert_eq!(client.messages_sent(), 3);

        let server_count = server_handle.join().unwrap();
        assert_eq!(server_count, 3);
    }

    #[test]
    fn test_channel_wrong_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password1", "service");
        let server_config = ChannelConfig::new("password2", "service");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // The server rejects the client's key confirmation and drops its stream.
        // The client then hits EOF while waiting for the server's confirmation,
        // so both sides must fail.
        assert!(matches!(
            server_result,
            Err(ShieldError::AuthenticationFailed)
        ));
        assert!(client_result.is_err());
    }

    #[test]
    fn test_rejects_unknown_protocol_version() {
        let (mut client_stream, server_stream) = MockStream::pair();
        client_stream
            .write_all(&[
                PROTOCOL_VERSION.wrapping_add(1),
                HandshakeType::ClientHello as u8,
                0,
                16,
            ])
            .unwrap();

        let result = ShieldChannel::accept(server_stream, &ChannelConfig::new("pw", "svc"));
        assert!(matches!(result, Err(ShieldError::ChannelError(_))));
    }

    #[test]
    fn test_config_builder() {
        let config = ChannelConfig::new("password", "service")
            .with_iterations(100_000)
            .with_timeout(5_000);

        assert_eq!(config.iterations, 100_000);
        assert_eq!(config.handshake_timeout_ms, 5_000);
    }

    #[test]
    fn test_empty_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            server.recv().unwrap()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"").unwrap();

        let received = server_handle.join().unwrap();
        assert_eq!(received, b"");
    }

    #[test]
    fn test_large_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let large_data: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();

        let server_config = config.clone();
        let expected_data = large_data.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let received = server.recv().unwrap();
            assert_eq!(received.len(), expected_data.len());
            assert_eq!(received, expected_data);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(&large_data).unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_bidirectional_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();

            server.send(b"Server says hello").unwrap();

            let msg = server.recv().unwrap();
            assert_eq!(msg, b"Client responds");

            server.send(b"Server acknowledges").unwrap();
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        let msg1 = client.recv().unwrap();
        assert_eq!(msg1, b"Server says hello");

        client.send(b"Client responds").unwrap();

        let msg2 = client.recv().unwrap();
        assert_eq!(msg2, b"Server acknowledges");

        server_handle.join().unwrap();
    }

    #[test]
    fn test_different_services_same_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password", "service1");
        let server_config = ChannelConfig::new("password", "service2");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // NOTE(review): the service label is not mixed into key derivation, so this
        // handshake is expected to succeed. The check is kept conditional rather
        // than pinning that behavior.
        if client_result.is_ok() && server_result.is_ok() {
            assert_eq!(client_result.unwrap().service(), "service1");
            assert_eq!(server_result.unwrap().service(), "service2");
        }
    }

    #[test]
    fn test_unique_ciphertext_per_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg1 = server.recv().unwrap();
            let msg2 = server.recv().unwrap();
            assert_eq!(msg1, msg2);
            assert_eq!(msg1, b"same message");
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        // Encrypt by hand (mirroring `send`) so the two ciphertexts can be compared
        // before they are framed and sent.
        let first = client.session.encrypt(b"same message").unwrap();
        let second = client.session.encrypt(b"same message").unwrap();
        assert_ne!(
            first, second,
            "identical plaintexts must not produce identical ciphertexts"
        );

        ShieldChannel::<MockStream>::write_frame(&mut client.stream, &first).unwrap();
        ShieldChannel::<MockStream>::write_frame(&mut client.stream, &second).unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_listener_config() {
        let config = ChannelConfig::new("password", "service");
        let listener = ShieldListener::new((), config.clone());

        assert_eq!(listener.config().service(), "service");

        let _ = listener.into_inner();
    }

    #[test]
    fn test_channel_counters() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            assert_eq!(server.messages_sent(), 0);
            assert_eq!(server.messages_received(), 0);

            let _ = server.recv().unwrap();
            assert_eq!(server.messages_received(), 1);

            server.send(b"response").unwrap();
            assert_eq!(server.messages_sent(), 1);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        assert_eq!(client.messages_sent(), 0);

        client.send(b"hello").unwrap();
        assert_eq!(client.messages_sent(), 1);

        let _ = client.recv().unwrap();
        assert_eq!(client.messages_received(), 1);

        server_handle.join().unwrap();
    }
}
800}