1#![allow(clippy::cast_possible_truncation)]
39
40use ring::hmac;
41use std::io::{Read, Write};
42use subtle::ConstantTimeEq;
43use zeroize::Zeroize;
44
45use crate::error::{Result, ShieldError};
46use crate::exchange::PAKEExchange;
47use crate::ratchet::RatchetSession;
48
/// Version byte written at the start of every handshake header and checked on receipt.
const PROTOCOL_VERSION: u8 = 1;

/// Upper bound (16 MiB) that `ChannelConfig::with_max_message_size` clamps to.
const MAX_MESSAGE_SIZE_CAP: usize = 16 << 20;

/// Message-size limit (1 MiB) a fresh `ChannelConfig` starts with.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 1 << 20;
57
/// Handshake message kinds; the `u8` discriminant is the second byte of every
/// handshake header on the wire.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HandshakeType {
    /// Client -> server: the client's 16-byte random salt.
    ClientHello = 1,
    /// Server -> client: the mixed 16-byte salt followed by the server's 32-byte contribution.
    ServerHello = 2,
    /// Client -> server: the client's 32-byte contribution.
    Finished = 3,
}
66
67#[derive(Clone)]
69pub struct ChannelConfig {
70 password: String,
72 service: String,
74 iterations: u32,
76 handshake_timeout_ms: u64,
78 max_message_size: usize,
80}
81
82impl ChannelConfig {
83 #[must_use]
89 pub fn new(password: &str, service: &str) -> Self {
90 Self {
91 password: password.to_string(),
92 service: service.to_string(),
93 iterations: PAKEExchange::ITERATIONS,
94 handshake_timeout_ms: 30_000,
95 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
96 }
97 }
98
99 #[must_use]
101 pub fn with_iterations(mut self, iterations: u32) -> Self {
102 self.iterations = iterations;
103 self
104 }
105
106 #[must_use]
108 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
109 self.handshake_timeout_ms = timeout_ms;
110 self
111 }
112
113 #[must_use]
115 pub fn with_max_message_size(mut self, size: usize) -> Self {
116 self.max_message_size = size.min(MAX_MESSAGE_SIZE_CAP);
117 self
118 }
119
120 #[cfg(feature = "async")]
122 #[must_use]
123 pub(crate) fn password(&self) -> &str {
124 &self.password
125 }
126
127 #[must_use]
129 pub fn service(&self) -> &str {
130 &self.service
131 }
132
133 #[cfg(feature = "async")]
135 #[must_use]
136 pub(crate) fn iterations(&self) -> u32 {
137 self.iterations
138 }
139
140 #[cfg(feature = "async")]
142 #[must_use]
143 pub(crate) fn handshake_timeout_ms(&self) -> u64 {
144 self.handshake_timeout_ms
145 }
146
147 #[must_use]
149 pub fn max_message_size(&self) -> usize {
150 self.max_message_size
151 }
152}
153
/// Per-connection handshake material; its byte fields are wiped by its `Drop` impl.
struct HandshakeState {
    /// Salt fed to every `PAKEExchange::derive` call. On the server it is mixed
    /// with the client's salt; on the client it is replaced by the server's value.
    salt: [u8; 16],
    /// This side's contribution (all zeros until derived).
    local_contribution: [u8; 32],
    /// The peer's contribution, once received.
    remote_contribution: Option<[u8; 32]>,
    /// `true` on the connecting (client) side.
    is_initiator: bool,
}
161
162impl Drop for HandshakeState {
163 fn drop(&mut self) {
164 self.salt.zeroize();
165 self.local_contribution.zeroize();
166 if let Some(ref mut remote) = self.remote_contribution {
167 remote.zeroize();
168 }
169 }
170}
171
172impl HandshakeState {
173 fn new(is_initiator: bool) -> Result<Self> {
174 let salt: [u8; 16] = crate::random::random_bytes()?;
175
176 Ok(Self {
177 salt,
178 local_contribution: [0u8; 32],
179 remote_contribution: None,
180 is_initiator,
181 })
182 }
183
184 fn derive_contribution(&mut self, config: &ChannelConfig) {
185 let role = if self.is_initiator {
186 "client"
187 } else {
188 "server"
189 };
190 self.local_contribution =
191 PAKEExchange::derive(&config.password, &self.salt, role, Some(config.iterations));
192 }
193
194 fn compute_session_key(&self, config: &ChannelConfig) -> Result<[u8; 32]> {
195 let remote = self
196 .remote_contribution
197 .ok_or_else(|| ShieldError::ChannelError("handshake incomplete".into()))?;
198
199 let base_key = PAKEExchange::combine(&[self.local_contribution, remote]);
203
204 let password_key = PAKEExchange::derive(
206 &config.password,
207 &self.salt,
208 "session",
209 Some(config.iterations),
210 );
211
212 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &base_key);
215 let tag = hmac::sign(&hmac_key, &password_key);
216 let mut result = [0u8; 32];
217 result.copy_from_slice(&tag.as_ref()[..32]);
218 Ok(result)
219 }
220}
221
222pub struct ShieldChannel<S> {
229 stream: S,
230 session: RatchetSession,
231 service: String,
232 max_message_size: usize,
233}
234
235impl<S: Read + Write> ShieldChannel<S> {
236 pub fn connect(mut stream: S, config: &ChannelConfig) -> Result<Self> {
244 let mut state = HandshakeState::new(true)?;
245
246 Self::send_handshake(&mut stream, HandshakeType::ClientHello, &state.salt)?;
248
249 let server_hello = Self::recv_handshake(&mut stream, HandshakeType::ServerHello)?;
251 if server_hello.len() != 48 {
252 return Err(ShieldError::ChannelError("invalid ServerHello".into()));
253 }
254
255 state.salt.copy_from_slice(&server_hello[..16]);
257 state.derive_contribution(config);
258
259 let mut remote = [0u8; 32];
260 remote.copy_from_slice(&server_hello[16..48]);
261 state.remote_contribution = Some(remote);
262
263 Self::send_handshake(
265 &mut stream,
266 HandshakeType::Finished,
267 &state.local_contribution,
268 )?;
269
270 let session_key = state.compute_session_key(config)?;
272 let session = RatchetSession::new(&session_key, true);
273
274 Self::send_confirmation(&mut stream, &session_key, true)?;
276 Self::verify_confirmation(&mut stream, &session_key, false)?;
277
278 Ok(Self {
279 stream,
280 session,
281 service: config.service.clone(),
282 max_message_size: config.max_message_size,
283 })
284 }
285
286 pub fn accept(mut stream: S, config: &ChannelConfig) -> Result<Self> {
294 let mut state = HandshakeState::new(false)?;
295
296 let client_hello = Self::recv_handshake(&mut stream, HandshakeType::ClientHello)?;
298 if client_hello.len() != 16 {
299 return Err(ShieldError::ChannelError("invalid ClientHello".into()));
300 }
301
302 for (i, &b) in client_hello.iter().enumerate() {
304 state.salt[i] ^= b;
305 }
306
307 state.derive_contribution(config);
308
309 let mut server_hello = Vec::with_capacity(48);
311 server_hello.extend_from_slice(&state.salt);
312 server_hello.extend_from_slice(&state.local_contribution);
313 Self::send_handshake(&mut stream, HandshakeType::ServerHello, &server_hello)?;
314
315 let client_finished = Self::recv_handshake(&mut stream, HandshakeType::Finished)?;
317 if client_finished.len() != 32 {
318 return Err(ShieldError::ChannelError("invalid Finished".into()));
319 }
320
321 let mut remote = [0u8; 32];
322 remote.copy_from_slice(&client_finished);
323 state.remote_contribution = Some(remote);
324
325 let session_key = state.compute_session_key(config)?;
327 let session = RatchetSession::new(&session_key, false);
328
329 Self::verify_confirmation(&mut stream, &session_key, true)?;
331 Self::send_confirmation(&mut stream, &session_key, false)?;
332
333 Ok(Self {
334 stream,
335 session,
336 service: config.service.clone(),
337 max_message_size: config.max_message_size,
338 })
339 }
340
341 pub fn send(&mut self, data: &[u8]) -> Result<()> {
345 if data.len() > self.max_message_size {
346 return Err(ShieldError::ChannelError(format!(
347 "message too large: {} > {}",
348 data.len(),
349 self.max_message_size
350 )));
351 }
352
353 let encrypted = self.session.encrypt(data)?;
354 Self::write_frame(&mut self.stream, &encrypted, self.max_message_size)
355 }
356
357 pub fn recv(&mut self) -> Result<Vec<u8>> {
361 let encrypted = Self::read_frame(&mut self.stream, self.max_message_size)?;
362 self.session.decrypt(&encrypted)
363 }
364
365 #[must_use]
367 pub fn service(&self) -> &str {
368 &self.service
369 }
370
371 #[must_use]
373 pub fn messages_sent(&self) -> u64 {
374 self.session.send_counter()
375 }
376
377 #[must_use]
379 pub fn messages_received(&self) -> u64 {
380 self.session.recv_counter()
381 }
382
383 pub fn into_inner(self) -> S {
385 self.stream
386 }
387
388 fn send_handshake(stream: &mut S, msg_type: HandshakeType, data: &[u8]) -> Result<()> {
391 let mut frame = Vec::with_capacity(1 + 1 + 2 + data.len());
392 frame.push(PROTOCOL_VERSION);
393 frame.push(msg_type as u8);
394 frame.extend_from_slice(&(data.len() as u16).to_be_bytes());
395 frame.extend_from_slice(data);
396
397 stream
398 .write_all(&frame)
399 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
400 stream
401 .flush()
402 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
403 Ok(())
404 }
405
406 fn recv_handshake(stream: &mut S, expected_type: HandshakeType) -> Result<Vec<u8>> {
407 let mut header = [0u8; 4];
408 stream
409 .read_exact(&mut header)
410 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
411
412 if header[0] != PROTOCOL_VERSION {
413 return Err(ShieldError::ChannelError(format!(
414 "unsupported protocol version: {}",
415 header[0]
416 )));
417 }
418
419 if header[1] != expected_type as u8 {
420 return Err(ShieldError::ChannelError(format!(
421 "unexpected message type: expected {}, got {}",
422 expected_type as u8, header[1]
423 )));
424 }
425
426 let len = u16::from_be_bytes([header[2], header[3]]) as usize;
427 if len > 1024 {
428 return Err(ShieldError::ChannelError(
429 "handshake message too large".into(),
430 ));
431 }
432
433 let mut data = vec![0u8; len];
434 stream
435 .read_exact(&mut data)
436 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
437
438 Ok(data)
439 }
440
441 fn send_confirmation(stream: &mut S, session_key: &[u8; 32], is_client: bool) -> Result<()> {
442 let label = if is_client {
443 b"client-confirm"
444 } else {
445 b"server-confirm"
446 };
447 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
448 let confirm = hmac::sign(&hmac_key, label);
449
450 Self::write_frame(stream, &confirm.as_ref()[..16], DEFAULT_MAX_MESSAGE_SIZE)
451 }
452
453 fn verify_confirmation(
454 stream: &mut S,
455 session_key: &[u8; 32],
456 expect_client: bool,
457 ) -> Result<()> {
458 let received = Self::read_frame(stream, DEFAULT_MAX_MESSAGE_SIZE)?;
459 if received.len() != 16 {
460 return Err(ShieldError::ChannelError("invalid confirmation".into()));
461 }
462
463 let label = if expect_client {
464 b"client-confirm"
465 } else {
466 b"server-confirm"
467 };
468 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
469 let expected = hmac::sign(&hmac_key, label);
470
471 if received.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
472 return Err(ShieldError::AuthenticationFailed);
473 }
474
475 Ok(())
476 }
477
478 fn write_frame(stream: &mut S, data: &[u8], max_size: usize) -> Result<()> {
481 if data.len() > max_size {
482 return Err(ShieldError::ChannelError(format!(
483 "frame too large to send: {} > {max_size}",
484 data.len()
485 )));
486 }
487 let len = data.len() as u32;
488 stream
489 .write_all(&len.to_be_bytes())
490 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
491 stream
492 .write_all(data)
493 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
494 stream
495 .flush()
496 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
497 Ok(())
498 }
499
500 fn read_frame(stream: &mut S, max_size: usize) -> Result<Vec<u8>> {
501 let mut len_buf = [0u8; 4];
502 stream
503 .read_exact(&mut len_buf)
504 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
505
506 let len = u32::from_be_bytes(len_buf) as usize;
507 if len > max_size {
508 return Err(ShieldError::ChannelError(format!(
509 "frame too large: {len} > {max_size}"
510 )));
511 }
512
513 let mut data = vec![0u8; len];
514 stream
515 .read_exact(&mut data)
516 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
517
518 Ok(data)
519 }
520}
521
522impl ShieldChannel<std::net::TcpStream> {
523 pub fn connect_tcp(stream: std::net::TcpStream, config: &ChannelConfig) -> Result<Self> {
528 let timeout = std::time::Duration::from_millis(config.handshake_timeout_ms);
529 stream
530 .set_read_timeout(Some(timeout))
531 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
532 stream
533 .set_write_timeout(Some(timeout))
534 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
535
536 match Self::connect(stream, config) {
537 Ok(channel) => {
538 channel.stream.set_read_timeout(None).ok();
539 channel.stream.set_write_timeout(None).ok();
540 Ok(channel)
541 }
542 Err(e) => Err(e),
543 }
544 }
545
546 pub fn accept_tcp(stream: std::net::TcpStream, config: &ChannelConfig) -> Result<Self> {
551 let timeout = std::time::Duration::from_millis(config.handshake_timeout_ms);
552 stream
553 .set_read_timeout(Some(timeout))
554 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
555 stream
556 .set_write_timeout(Some(timeout))
557 .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
558
559 match Self::accept(stream, config) {
560 Ok(channel) => {
561 channel.stream.set_read_timeout(None).ok();
562 channel.stream.set_write_timeout(None).ok();
563 Ok(channel)
564 }
565 Err(e) => Err(e),
566 }
567 }
568}
569
570pub struct ShieldListener<L> {
572 listener: L,
573 config: ChannelConfig,
574}
575
576impl<L> ShieldListener<L> {
577 #[must_use]
579 pub fn new(listener: L, config: ChannelConfig) -> Self {
580 Self { listener, config }
581 }
582
583 pub fn into_inner(self) -> L {
585 self.listener
586 }
587
588 #[must_use]
590 pub fn config(&self) -> &ChannelConfig {
591 &self.config
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{mpsc, Arc, Mutex};
    use std::thread;

    /// In-memory duplex stream: bytes written on one end are read on the other.
    struct MockStream {
        tx: mpsc::Sender<u8>,
        rx: mpsc::Receiver<u8>,
    }

    impl MockStream {
        fn pair() -> (Self, Self) {
            let (tx1, rx1) = mpsc::channel();
            let (tx2, rx2) = mpsc::channel();

            (Self { tx: tx1, rx: rx2 }, Self { tx: tx2, rx: rx1 })
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            for byte in buf.iter_mut() {
                *byte = self.rx.recv().map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "channel closed")
                })?;
            }
            Ok(buf.len())
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            for &byte in buf {
                self.tx.send(byte).map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed")
                })?;
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Wraps a stream and records every byte written through it, so tests can
    /// inspect what actually goes on the wire.
    struct TapStream<T> {
        inner: T,
        written: Arc<Mutex<Vec<u8>>>,
    }

    impl<T: Read> Read for TapStream<T> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.inner.read(buf)
        }
    }

    impl<T: Write> Write for TapStream<T> {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let n = self.inner.write(buf)?;
            self.written.lock().unwrap().extend_from_slice(&buf[..n]);
            Ok(n)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.inner.flush()
        }
    }

    /// Splits one length-prefixed frame off the front of `bytes`:
    /// returns `(payload, remainder)`.
    fn split_frame(bytes: &[u8]) -> (&[u8], &[u8]) {
        let (len, rest) = bytes.split_at(4);
        let len = u32::from_be_bytes([len[0], len[1], len[2], len[3]]) as usize;
        rest.split_at(len)
    }

    #[test]
    fn test_channel_handshake() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("test-password", "test.service");

        let server_config = config.clone();
        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client = ShieldChannel::connect(client_stream, &config).unwrap();
        let server = server_handle.join().unwrap().unwrap();

        assert_eq!(client.service(), "test.service");
        assert_eq!(server.service(), "test.service");
    }

    #[test]
    fn test_channel_message_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("secret", "messaging.app");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg = server.recv().unwrap();
            server.send(b"Hello client!").unwrap();
            msg
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"Hello server!").unwrap();
        let response = client.recv().unwrap();

        let server_received = server_handle.join().unwrap();

        assert_eq!(server_received, b"Hello server!");
        assert_eq!(response, b"Hello client!");
    }

    /// Despite its name, this checks that several messages flow in order and the
    /// counters agree; key erasure itself is not observable here.
    #[test]
    fn test_channel_forward_secrecy() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            for _ in 0..3 {
                let _ = server.recv().unwrap();
            }
            server.messages_received()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        client.send(b"message 1").unwrap();
        client.send(b"message 2").unwrap();
        client.send(b"message 3").unwrap();

        assert_eq!(client.messages_sent(), 3);

        let server_count = server_handle.join().unwrap();
        assert_eq!(server_count, 3);
    }

    #[test]
    fn test_channel_wrong_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password1", "service");
        let server_config = ChannelConfig::new("password2", "service");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // The server checks the client's confirmation first, so it is the side that
        // detects the mismatch; the client then sees the connection close.
        assert!(matches!(
            server_result,
            Err(ShieldError::AuthenticationFailed)
        ));
        assert!(client_result.is_err());
    }

    #[test]
    fn test_config_builder() {
        let config = ChannelConfig::new("password", "service")
            .with_iterations(100_000)
            .with_timeout(5_000);

        assert_eq!(config.iterations, 100_000);
        assert_eq!(config.handshake_timeout_ms, 5_000);

        let capped = ChannelConfig::new("password", "service").with_max_message_size(usize::MAX);
        assert_eq!(capped.max_message_size(), MAX_MESSAGE_SIZE_CAP);
    }

    #[test]
    fn test_empty_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            server.recv().unwrap()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"").unwrap();

        let received = server_handle.join().unwrap();
        assert_eq!(received, b"");
    }

    #[test]
    fn test_large_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let large_data: Vec<u8> = (0..65536_u32).map(|i| (i % 256) as u8).collect();

        let server_config = config.clone();
        let expected_data = large_data.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let received = server.recv().unwrap();
            assert_eq!(received.len(), expected_data.len());
            assert_eq!(received, expected_data);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(&large_data).unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_send_rejects_oversized_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service").with_max_message_size(64);

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            ShieldChannel::accept(server_stream, &server_config).map(|_| ())
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        server_handle.join().unwrap().unwrap();

        // Rejected before encryption, so the send counter must not move.
        assert!(client.send(&[0u8; 65]).is_err());
        assert_eq!(client.messages_sent(), 0);
    }

    #[test]
    fn test_bidirectional_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();

            server.send(b"Server says hello").unwrap();

            let msg = server.recv().unwrap();
            assert_eq!(msg, b"Client responds");

            server.send(b"Server acknowledges").unwrap();
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        let msg1 = client.recv().unwrap();
        assert_eq!(msg1, b"Server says hello");

        client.send(b"Client responds").unwrap();

        let msg2 = client.recv().unwrap();
        assert_eq!(msg2, b"Server acknowledges");

        server_handle.join().unwrap();
    }

    #[test]
    fn test_different_services_same_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password", "service1");
        let server_config = ChannelConfig::new("password", "service2");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // The service name is only a local label and is not mixed into the
        // handshake, so peers sharing a password connect even across services.
        // If service binding is ever added, this test should expect failure instead.
        let client = client_result.unwrap();
        let server = server_result.unwrap();
        assert_eq!(client.service(), "service1");
        assert_eq!(server.service(), "service2");
    }

    #[test]
    fn test_unique_ciphertext_per_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg1 = server.recv().unwrap();
            let msg2 = server.recv().unwrap();
            assert_eq!(msg1, msg2);
            assert_eq!(msg1, b"same message");
        });

        let wire = Arc::new(Mutex::new(Vec::new()));
        let tapped = TapStream {
            inner: client_stream,
            written: Arc::clone(&wire),
        };
        let mut client = ShieldChannel::connect(tapped, &config).unwrap();

        // The handshake is finished; record only the two data frames.
        wire.lock().unwrap().clear();
        client.send(b"same message").unwrap();
        client.send(b"same message").unwrap();

        server_handle.join().unwrap();

        let recorded = wire.lock().unwrap().clone();
        let (first, rest) = split_frame(&recorded);
        let (second, tail) = split_frame(rest);
        assert!(tail.is_empty());
        assert_ne!(
            first, second,
            "identical plaintexts must not produce identical ciphertexts"
        );
    }

    #[test]
    fn test_listener_config() {
        let config = ChannelConfig::new("password", "service");
        let listener = ShieldListener::new((), config.clone());

        assert_eq!(listener.config().service(), "service");

        listener.into_inner();
    }

    #[test]
    fn test_channel_counters() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            assert_eq!(server.messages_sent(), 0);
            assert_eq!(server.messages_received(), 0);

            let _ = server.recv().unwrap();
            assert_eq!(server.messages_received(), 1);

            server.send(b"response").unwrap();
            assert_eq!(server.messages_sent(), 1);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        assert_eq!(client.messages_sent(), 0);

        client.send(b"hello").unwrap();
        assert_eq!(client.messages_sent(), 1);

        let _ = client.recv().unwrap();
        assert_eq!(client.messages_received(), 1);

        server_handle.join().unwrap();
    }
}