Skip to main content

shield_core/
channel.rs

1//! Shield Secure Channel - TLS/SSH-like secure transport using symmetric crypto.
2//!
3//! Provides encrypted bidirectional communication with:
4//! - PAKE-based handshake (no certificates needed)
5//! - Forward secrecy via key ratcheting
6//! - Message authentication and replay protection
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use shield_core::channel::{ShieldChannel, ChannelConfig};
12//! use std::net::{TcpListener, TcpStream};
13//!
14//! fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
15//!     let config = ChannelConfig::new("shared-secret", "my-service");
//!
//!     // Bind before spawning the server thread so the client below cannot
//!     // attempt to connect before the port is listening.
//!     let listener = TcpListener::bind("127.0.0.1:9876")?;
//!
//!     // Server thread
//!     let server_config = config.clone();
//!     let server = std::thread::spawn(move || {
//!         let (stream, _) = listener.accept().unwrap();
//!         let mut ch = ShieldChannel::accept(stream, &server_config).unwrap();
//!         let msg = ch.recv().unwrap();
//!         assert_eq!(msg, b"Hello server!");
//!     });
26//!
27//!     // Client side
28//!     let stream = TcpStream::connect("127.0.0.1:9876")?;
29//!     let mut client = ShieldChannel::connect(stream, &config)?;
30//!     client.send(b"Hello server!")?;
31//!
32//!     server.join().unwrap();
33//!     Ok(())
34//! }
35//! ```
36
// Wire-format length fields are intentionally narrower than usize (u16 handshake lengths, u32 frame lengths).
38#![allow(clippy::cast_possible_truncation)]
39
40use ring::hmac;
41use std::io::{Read, Write};
42use subtle::ConstantTimeEq;
43use zeroize::Zeroize;
44
45use crate::error::{Result, ShieldError};
46use crate::exchange::PAKEExchange;
47use crate::ratchet::RatchetSession;
48
/// Version byte that opens every handshake message header.
const PROTOCOL_VERSION: u8 = 0x01;

/// Hard ceiling for any configured message size: 16 MiB.
const MAX_MESSAGE_SIZE_CAP: usize = 16 << 20;

/// Default message size limit: 1 MiB.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 1 << 20;

/// Handshake message types; the discriminant is the second header byte on the wire.
#[repr(u8)]
#[derive(Copy, Clone)]
enum HandshakeType {
    /// Client's opening message carrying its salt proposal.
    ClientHello = 0x01,
    /// Server reply: final salt followed by the server contribution.
    ServerHello = 0x02,
    /// Client contribution that completes the exchange.
    Finished = 0x03,
}
66
67/// Channel configuration.
68#[derive(Clone)]
69pub struct ChannelConfig {
70    /// Shared password for PAKE.
71    password: String,
72    /// Service identifier (domain binding).
73    service: String,
74    /// PBKDF2 iterations for key derivation.
75    iterations: u32,
76    /// Handshake timeout in milliseconds.
77    handshake_timeout_ms: u64,
78    /// Maximum message size in bytes (default 1 MB, capped at 16 MB).
79    max_message_size: usize,
80}
81
82impl ChannelConfig {
83    /// Create new channel configuration.
84    ///
85    /// # Arguments
86    /// * `password` - Shared secret between parties
87    /// * `service` - Service identifier for domain separation
88    #[must_use]
89    pub fn new(password: &str, service: &str) -> Self {
90        Self {
91            password: password.to_string(),
92            service: service.to_string(),
93            iterations: PAKEExchange::ITERATIONS,
94            handshake_timeout_ms: 30_000,
95            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
96        }
97    }
98
99    /// Set custom PBKDF2 iterations.
100    #[must_use]
101    pub fn with_iterations(mut self, iterations: u32) -> Self {
102        self.iterations = iterations;
103        self
104    }
105
106    /// Set handshake timeout.
107    #[must_use]
108    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
109        self.handshake_timeout_ms = timeout_ms;
110        self
111    }
112
113    /// Set maximum message size in bytes (capped at 16 MB).
114    #[must_use]
115    pub fn with_max_message_size(mut self, size: usize) -> Self {
116        self.max_message_size = size.min(MAX_MESSAGE_SIZE_CAP);
117        self
118    }
119
120    /// Get password (for internal use by async channel).
121    #[cfg(feature = "async")]
122    #[must_use]
123    pub(crate) fn password(&self) -> &str {
124        &self.password
125    }
126
127    /// Get service identifier.
128    #[must_use]
129    pub fn service(&self) -> &str {
130        &self.service
131    }
132
133    /// Get iterations count (for internal use by async channel).
134    #[cfg(feature = "async")]
135    #[must_use]
136    pub(crate) fn iterations(&self) -> u32 {
137        self.iterations
138    }
139
140    /// Get handshake timeout in milliseconds (for internal use by async channel).
141    #[cfg(feature = "async")]
142    #[must_use]
143    pub(crate) fn handshake_timeout_ms(&self) -> u64 {
144        self.handshake_timeout_ms
145    }
146
147    /// Get maximum message size in bytes.
148    #[must_use]
149    pub fn max_message_size(&self) -> usize {
150        self.max_message_size
151    }
152}
153
154/// Handshake state for key exchange.
155struct HandshakeState {
156    salt: [u8; 16],
157    local_contribution: [u8; 32],
158    remote_contribution: Option<[u8; 32]>,
159    is_initiator: bool,
160}
161
162impl Drop for HandshakeState {
163    fn drop(&mut self) {
164        self.salt.zeroize();
165        self.local_contribution.zeroize();
166        if let Some(ref mut remote) = self.remote_contribution {
167            remote.zeroize();
168        }
169    }
170}
171
172impl HandshakeState {
173    fn new(is_initiator: bool) -> Result<Self> {
174        let salt: [u8; 16] = crate::random::random_bytes()?;
175
176        Ok(Self {
177            salt,
178            local_contribution: [0u8; 32],
179            remote_contribution: None,
180            is_initiator,
181        })
182    }
183
184    fn derive_contribution(&mut self, config: &ChannelConfig) {
185        let role = if self.is_initiator {
186            "client"
187        } else {
188            "server"
189        };
190        self.local_contribution =
191            PAKEExchange::derive(&config.password, &self.salt, role, Some(config.iterations));
192    }
193
194    fn compute_session_key(&self, config: &ChannelConfig) -> Result<[u8; 32]> {
195        let remote = self
196            .remote_contribution
197            .ok_or_else(|| ShieldError::ChannelError("handshake incomplete".into()))?;
198
199        // CRITICAL: Include password-derived key in session key computation
200        // This ensures different passwords produce different session keys
201        // even though contributions are exchanged.
202        let base_key = PAKEExchange::combine(&[self.local_contribution, remote]);
203
204        // Mix in the password-derived secret that wasn't exchanged
205        let password_key = PAKEExchange::derive(
206            &config.password,
207            &self.salt,
208            "session",
209            Some(config.iterations),
210        );
211
212        // Final session key = HMAC-SHA256(base_key, password_key)
213        // Using keyed HMAC instead of SHA256(key || data) to prevent length-extension
214        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &base_key);
215        let tag = hmac::sign(&hmac_key, &password_key);
216        let mut result = [0u8; 32];
217        result.copy_from_slice(&tag.as_ref()[..32]);
218        Ok(result)
219    }
220}
221
222/// Shield secure channel for encrypted communication.
223///
224/// Provides TLS-like security using only symmetric cryptography:
225/// - PAKE handshake establishes shared key from password
226/// - `RatchetSession` provides forward secrecy
227/// - All messages authenticated with HMAC
228pub struct ShieldChannel<S> {
229    stream: S,
230    session: RatchetSession,
231    service: String,
232    max_message_size: usize,
233}
234
235impl<S: Read + Write> ShieldChannel<S> {
236    /// Connect as client (initiator).
237    ///
238    /// Performs PAKE handshake and establishes encrypted channel.
239    ///
240    /// # Arguments
241    /// * `stream` - Underlying transport (TCP, etc.)
242    /// * `config` - Channel configuration with shared password
243    pub fn connect(mut stream: S, config: &ChannelConfig) -> Result<Self> {
244        let mut state = HandshakeState::new(true)?;
245
246        // Step 1: Send ClientHello (salt)
247        Self::send_handshake(&mut stream, HandshakeType::ClientHello, &state.salt)?;
248
249        // Step 2: Receive ServerHello (server's salt + contribution)
250        let server_hello = Self::recv_handshake(&mut stream, HandshakeType::ServerHello)?;
251        if server_hello.len() != 48 {
252            return Err(ShieldError::ChannelError("invalid ServerHello".into()));
253        }
254
255        // Use server's salt for key derivation
256        state.salt.copy_from_slice(&server_hello[..16]);
257        state.derive_contribution(config);
258
259        let mut remote = [0u8; 32];
260        remote.copy_from_slice(&server_hello[16..48]);
261        state.remote_contribution = Some(remote);
262
263        // Step 3: Send client contribution
264        Self::send_handshake(
265            &mut stream,
266            HandshakeType::Finished,
267            &state.local_contribution,
268        )?;
269
270        // Derive session key and create ratchet
271        let session_key = state.compute_session_key(config)?;
272        let session = RatchetSession::new(&session_key, true);
273
274        // Verify handshake with confirmation message
275        Self::send_confirmation(&mut stream, &session_key, true)?;
276        Self::verify_confirmation(&mut stream, &session_key, false)?;
277
278        Ok(Self {
279            stream,
280            session,
281            service: config.service.clone(),
282            max_message_size: config.max_message_size,
283        })
284    }
285
286    /// Accept connection as server.
287    ///
288    /// Waits for client handshake and establishes encrypted channel.
289    ///
290    /// # Arguments
291    /// * `stream` - Underlying transport (TCP, etc.)
292    /// * `config` - Channel configuration with shared password
293    pub fn accept(mut stream: S, config: &ChannelConfig) -> Result<Self> {
294        let mut state = HandshakeState::new(false)?;
295
296        // Step 1: Receive ClientHello (client's proposed salt)
297        let client_hello = Self::recv_handshake(&mut stream, HandshakeType::ClientHello)?;
298        if client_hello.len() != 16 {
299            return Err(ShieldError::ChannelError("invalid ClientHello".into()));
300        }
301
302        // Mix client salt with server salt for freshness
303        for (i, &b) in client_hello.iter().enumerate() {
304            state.salt[i] ^= b;
305        }
306
307        state.derive_contribution(config);
308
309        // Step 2: Send ServerHello (final salt + server contribution)
310        let mut server_hello = Vec::with_capacity(48);
311        server_hello.extend_from_slice(&state.salt);
312        server_hello.extend_from_slice(&state.local_contribution);
313        Self::send_handshake(&mut stream, HandshakeType::ServerHello, &server_hello)?;
314
315        // Step 3: Receive client contribution
316        let client_finished = Self::recv_handshake(&mut stream, HandshakeType::Finished)?;
317        if client_finished.len() != 32 {
318            return Err(ShieldError::ChannelError("invalid Finished".into()));
319        }
320
321        let mut remote = [0u8; 32];
322        remote.copy_from_slice(&client_finished);
323        state.remote_contribution = Some(remote);
324
325        // Derive session key and create ratchet
326        let session_key = state.compute_session_key(config)?;
327        let session = RatchetSession::new(&session_key, false);
328
329        // Verify handshake with confirmation message
330        Self::verify_confirmation(&mut stream, &session_key, true)?;
331        Self::send_confirmation(&mut stream, &session_key, false)?;
332
333        Ok(Self {
334            stream,
335            session,
336            service: config.service.clone(),
337            max_message_size: config.max_message_size,
338        })
339    }
340
341    /// Send encrypted message.
342    ///
343    /// Message is encrypted with current ratchet key, then key advances.
344    pub fn send(&mut self, data: &[u8]) -> Result<()> {
345        if data.len() > self.max_message_size {
346            return Err(ShieldError::ChannelError(format!(
347                "message too large: {} > {}",
348                data.len(),
349                self.max_message_size
350            )));
351        }
352
353        let encrypted = self.session.encrypt(data)?;
354        Self::write_frame(&mut self.stream, &encrypted, self.max_message_size)
355    }
356
357    /// Receive and decrypt message.
358    ///
359    /// Verifies authentication and advances receive ratchet.
360    pub fn recv(&mut self) -> Result<Vec<u8>> {
361        let encrypted = Self::read_frame(&mut self.stream, self.max_message_size)?;
362        self.session.decrypt(&encrypted)
363    }
364
365    /// Get service identifier.
366    #[must_use]
367    pub fn service(&self) -> &str {
368        &self.service
369    }
370
371    /// Get send message count.
372    #[must_use]
373    pub fn messages_sent(&self) -> u64 {
374        self.session.send_counter()
375    }
376
377    /// Get receive message count.
378    #[must_use]
379    pub fn messages_received(&self) -> u64 {
380        self.session.recv_counter()
381    }
382
383    /// Get underlying stream (for shutdown, etc.)
384    pub fn into_inner(self) -> S {
385        self.stream
386    }
387
388    // --- Internal handshake helpers ---
389
390    fn send_handshake(stream: &mut S, msg_type: HandshakeType, data: &[u8]) -> Result<()> {
391        let mut frame = Vec::with_capacity(1 + 1 + 2 + data.len());
392        frame.push(PROTOCOL_VERSION);
393        frame.push(msg_type as u8);
394        frame.extend_from_slice(&(data.len() as u16).to_be_bytes());
395        frame.extend_from_slice(data);
396
397        stream
398            .write_all(&frame)
399            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
400        stream
401            .flush()
402            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
403        Ok(())
404    }
405
406    fn recv_handshake(stream: &mut S, expected_type: HandshakeType) -> Result<Vec<u8>> {
407        let mut header = [0u8; 4];
408        stream
409            .read_exact(&mut header)
410            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
411
412        if header[0] != PROTOCOL_VERSION {
413            return Err(ShieldError::ChannelError(format!(
414                "unsupported protocol version: {}",
415                header[0]
416            )));
417        }
418
419        if header[1] != expected_type as u8 {
420            return Err(ShieldError::ChannelError(format!(
421                "unexpected message type: expected {}, got {}",
422                expected_type as u8, header[1]
423            )));
424        }
425
426        let len = u16::from_be_bytes([header[2], header[3]]) as usize;
427        if len > 1024 {
428            return Err(ShieldError::ChannelError(
429                "handshake message too large".into(),
430            ));
431        }
432
433        let mut data = vec![0u8; len];
434        stream
435            .read_exact(&mut data)
436            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
437
438        Ok(data)
439    }
440
441    fn send_confirmation(stream: &mut S, session_key: &[u8; 32], is_client: bool) -> Result<()> {
442        let label = if is_client {
443            b"client-confirm"
444        } else {
445            b"server-confirm"
446        };
447        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
448        let confirm = hmac::sign(&hmac_key, label);
449
450        Self::write_frame(stream, &confirm.as_ref()[..16], DEFAULT_MAX_MESSAGE_SIZE)
451    }
452
453    fn verify_confirmation(
454        stream: &mut S,
455        session_key: &[u8; 32],
456        expect_client: bool,
457    ) -> Result<()> {
458        let received = Self::read_frame(stream, DEFAULT_MAX_MESSAGE_SIZE)?;
459        if received.len() != 16 {
460            return Err(ShieldError::ChannelError("invalid confirmation".into()));
461        }
462
463        let label = if expect_client {
464            b"client-confirm"
465        } else {
466            b"server-confirm"
467        };
468        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, session_key);
469        let expected = hmac::sign(&hmac_key, label);
470
471        if received.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
472            return Err(ShieldError::AuthenticationFailed);
473        }
474
475        Ok(())
476    }
477
478    // --- Frame helpers ---
479
480    fn write_frame(stream: &mut S, data: &[u8], max_size: usize) -> Result<()> {
481        if data.len() > max_size {
482            return Err(ShieldError::ChannelError(format!(
483                "frame too large to send: {} > {max_size}",
484                data.len()
485            )));
486        }
487        let len = data.len() as u32;
488        stream
489            .write_all(&len.to_be_bytes())
490            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
491        stream
492            .write_all(data)
493            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
494        stream
495            .flush()
496            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
497        Ok(())
498    }
499
500    fn read_frame(stream: &mut S, max_size: usize) -> Result<Vec<u8>> {
501        let mut len_buf = [0u8; 4];
502        stream
503            .read_exact(&mut len_buf)
504            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
505
506        let len = u32::from_be_bytes(len_buf) as usize;
507        if len > max_size {
508            return Err(ShieldError::ChannelError(format!(
509                "frame too large: {len} > {max_size}"
510            )));
511        }
512
513        let mut data = vec![0u8; len];
514        stream
515            .read_exact(&mut data)
516            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
517
518        Ok(data)
519    }
520}
521
522impl ShieldChannel<std::net::TcpStream> {
523    /// Connect as client with handshake timeout enforcement.
524    ///
525    /// Sets socket read/write timeouts during handshake, then clears them
526    /// so post-handshake messaging is not affected.
527    pub fn connect_tcp(stream: std::net::TcpStream, config: &ChannelConfig) -> Result<Self> {
528        let timeout = std::time::Duration::from_millis(config.handshake_timeout_ms);
529        stream
530            .set_read_timeout(Some(timeout))
531            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
532        stream
533            .set_write_timeout(Some(timeout))
534            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
535
536        match Self::connect(stream, config) {
537            Ok(channel) => {
538                channel.stream.set_read_timeout(None).ok();
539                channel.stream.set_write_timeout(None).ok();
540                Ok(channel)
541            }
542            Err(e) => Err(e),
543        }
544    }
545
546    /// Accept connection as server with handshake timeout enforcement.
547    ///
548    /// Sets socket read/write timeouts during handshake, then clears them
549    /// so post-handshake messaging is not affected.
550    pub fn accept_tcp(stream: std::net::TcpStream, config: &ChannelConfig) -> Result<Self> {
551        let timeout = std::time::Duration::from_millis(config.handshake_timeout_ms);
552        stream
553            .set_read_timeout(Some(timeout))
554            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
555        stream
556            .set_write_timeout(Some(timeout))
557            .map_err(|e| ShieldError::ChannelError(e.to_string()))?;
558
559        match Self::accept(stream, config) {
560            Ok(channel) => {
561                channel.stream.set_read_timeout(None).ok();
562                channel.stream.set_write_timeout(None).ok();
563                Ok(channel)
564            }
565            Err(e) => Err(e),
566        }
567    }
568}
569
570/// Channel listener for accepting multiple connections.
571pub struct ShieldListener<L> {
572    listener: L,
573    config: ChannelConfig,
574}
575
576impl<L> ShieldListener<L> {
577    /// Create a new listener with the given configuration.
578    #[must_use]
579    pub fn new(listener: L, config: ChannelConfig) -> Self {
580        Self { listener, config }
581    }
582
583    /// Get the underlying listener.
584    pub fn into_inner(self) -> L {
585        self.listener
586    }
587
588    /// Get a reference to the configuration.
589    #[must_use]
590    pub fn config(&self) -> &ChannelConfig {
591        &self.config
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;

    /// In-memory bidirectional byte stream for testing.
    ///
    /// Bytes written to one end arrive, one at a time over an mpsc channel, at
    /// the other end; `read` blocks until the buffer is full or the peer hangs up.
    struct MockStream {
        tx: mpsc::Sender<u8>,
        rx: mpsc::Receiver<u8>,
    }

    impl MockStream {
        /// Creates two connected ends.
        fn pair() -> (Self, Self) {
            let (tx1, rx1) = mpsc::channel();
            let (tx2, rx2) = mpsc::channel();

            (Self { tx: tx1, rx: rx2 }, Self { tx: tx2, rx: rx1 })
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            for byte in buf.iter_mut() {
                *byte = self.rx.recv().map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "channel closed")
                })?;
            }
            Ok(buf.len())
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            for &byte in buf {
                self.tx.send(byte).map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed")
                })?;
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_channel_handshake() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("test-password", "test.service");

        let server_config = config.clone();
        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client = ShieldChannel::connect(client_stream, &config).unwrap();
        let server = server_handle.join().unwrap().unwrap();

        assert_eq!(client.service(), "test.service");
        assert_eq!(server.service(), "test.service");
    }

    #[test]
    fn test_channel_message_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("secret", "messaging.app");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg = server.recv().unwrap();
            server.send(b"Hello client!").unwrap();
            msg
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"Hello server!").unwrap();
        let response = client.recv().unwrap();

        let server_received = server_handle.join().unwrap();

        assert_eq!(server_received, b"Hello server!");
        assert_eq!(response, b"Hello client!");
    }

    #[test]
    fn test_channel_forward_secrecy() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            for _ in 0..3 {
                let _ = server.recv().unwrap();
            }
            server.messages_received()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        // Send multiple messages; each one advances the send ratchet.
        client.send(b"message 1").unwrap();
        client.send(b"message 2").unwrap();
        client.send(b"message 3").unwrap();

        assert_eq!(client.messages_sent(), 3);

        let server_count = server_handle.join().unwrap();
        assert_eq!(server_count, 3);
    }

    #[test]
    fn test_channel_wrong_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password1", "service");
        let server_config = ChannelConfig::new("password2", "service");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // At least one side should fail key confirmation.
        assert!(client_result.is_err() || server_result.is_err());
    }

    #[test]
    fn test_config_builder() {
        let config = ChannelConfig::new("password", "service")
            .with_iterations(100_000)
            .with_timeout(5_000);

        assert_eq!(config.iterations, 100_000);
        assert_eq!(config.handshake_timeout_ms, 5_000);
    }

    /// The message size limit defaults to 1 MB and is clamped to the 16 MB cap.
    #[test]
    fn test_max_message_size_is_capped() {
        let config = ChannelConfig::new("password", "service");
        assert_eq!(config.max_message_size(), DEFAULT_MAX_MESSAGE_SIZE);
        assert_eq!(
            config.clone().with_max_message_size(4096).max_message_size(),
            4096
        );
        assert_eq!(
            config.with_max_message_size(usize::MAX).max_message_size(),
            MAX_MESSAGE_SIZE_CAP
        );
    }

    /// An oversized plaintext is rejected without consuming a ratchet step.
    #[test]
    fn test_oversized_message_rejected_before_encryption() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service").with_max_message_size(64);

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            ShieldChannel::accept(server_stream, &server_config).map(|_| ())
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        server_handle.join().unwrap().unwrap();

        assert!(client.send(&[0u8; 65]).is_err());
        assert_eq!(client.messages_sent(), 0);
    }

    #[test]
    fn test_empty_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            server.recv().unwrap()
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(b"").unwrap();

        let received = server_handle.join().unwrap();
        assert_eq!(received, b"");
    }

    #[test]
    fn test_large_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        // 64KB message
        let large_data: Vec<u8> = (0..65536_u32).map(|i| (i % 256) as u8).collect();

        let server_config = config.clone();
        let expected_data = large_data.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let received = server.recv().unwrap();
            assert_eq!(received.len(), expected_data.len());
            assert_eq!(received, expected_data);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        client.send(&large_data).unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_bidirectional_exchange() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();

            // Server sends first
            server.send(b"Server says hello").unwrap();

            // Then receives
            let msg = server.recv().unwrap();
            assert_eq!(msg, b"Client responds");

            // Send another
            server.send(b"Server acknowledges").unwrap();
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        // Client receives first
        let msg1 = client.recv().unwrap();
        assert_eq!(msg1, b"Server says hello");

        // Client responds
        client.send(b"Client responds").unwrap();

        // Client receives acknowledgment
        let msg2 = client.recv().unwrap();
        assert_eq!(msg2, b"Server acknowledges");

        server_handle.join().unwrap();
    }

    #[test]
    fn test_different_services_same_password() {
        let (client_stream, server_stream) = MockStream::pair();
        let client_config = ChannelConfig::new("password", "service1");
        let server_config = ChannelConfig::new("password", "service2");

        let server_handle =
            thread::spawn(move || ShieldChannel::accept(server_stream, &server_config));

        let client_result = ShieldChannel::connect(client_stream, &client_config);
        let server_result = server_handle.join().unwrap();

        // The service string is not an input to key derivation in this module,
        // so peers sharing a password but naming different services currently
        // complete the handshake and each reports its own service. Kept as a
        // conditional check so it stays valid if service binding is added.
        if let (Ok(client), Ok(server)) = (client_result, server_result) {
            assert_eq!(client.service(), "service1");
            assert_eq!(server.service(), "service2");
        }
    }

    #[test]
    fn test_unique_ciphertext_per_message() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            let msg1 = server.recv().unwrap();
            let msg2 = server.recv().unwrap();
            // Same plaintext should decrypt correctly
            assert_eq!(msg1, msg2);
            assert_eq!(msg1, b"same message");
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();

        // Send same message twice - due to ratcheting, ciphertext will differ
        client.send(b"same message").unwrap();
        client.send(b"same message").unwrap();

        server_handle.join().unwrap();
    }

    #[test]
    fn test_listener_config() {
        let config = ChannelConfig::new("password", "service");
        let listener = ShieldListener::new((), config.clone());

        assert_eq!(listener.config().service(), "service");

        listener.into_inner();
    }

    #[test]
    fn test_channel_counters() {
        let (client_stream, server_stream) = MockStream::pair();
        let config = ChannelConfig::new("password", "service");

        let server_config = config.clone();
        let server_handle = thread::spawn(move || {
            let mut server = ShieldChannel::accept(server_stream, &server_config).unwrap();
            assert_eq!(server.messages_sent(), 0);
            assert_eq!(server.messages_received(), 0);

            let _ = server.recv().unwrap();
            assert_eq!(server.messages_received(), 1);

            server.send(b"response").unwrap();
            assert_eq!(server.messages_sent(), 1);
        });

        let mut client = ShieldChannel::connect(client_stream, &config).unwrap();
        assert_eq!(client.messages_sent(), 0);

        client.send(b"hello").unwrap();
        assert_eq!(client.messages_sent(), 1);

        let _ = client.recv().unwrap();
        assert_eq!(client.messages_received(), 1);

        server_handle.join().unwrap();
    }
}