Skip to main content

whatsapp_rust/socket/
noise_socket.rs

use crate::socket::error::{EncryptSendError, Result, SocketError};
use crate::transport::Transport;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use wacore::handshake::NoiseCipher;
8
/// Largest plaintext size, in bytes, that is encrypted directly on the sender
/// task. Payloads strictly larger than this are encrypted on a blocking thread
/// so a big message does not stall the async executor.
const INLINE_ENCRYPT_THRESHOLD: usize = 16 * 1024;
10
11/// Result type for send operations, returning both buffers for reuse.
12type SendResult = std::result::Result<(Vec<u8>, Vec<u8>), EncryptSendError>;
13
14/// A job sent to the dedicated sender task.
15struct SendJob {
16    plaintext_buf: Vec<u8>,
17    out_buf: Vec<u8>,
18    response_tx: oneshot::Sender<SendResult>,
19}
20
21pub struct NoiseSocket {
22    read_key: Arc<NoiseCipher>,
23    read_counter: Arc<AtomicU32>,
24    /// Channel to send jobs to the dedicated sender task.
25    /// Using a channel instead of a mutex avoids blocking callers while
26    /// the current send is in progress - they can enqueue their work and
27    /// await the result without holding a lock.
28    send_job_tx: mpsc::Sender<SendJob>,
29    /// Handle to the sender task. Aborted on drop to prevent resource leaks
30    /// if the task is stuck on a slow/hanging network operation.
31    sender_task_handle: JoinHandle<()>,
32}
33
34impl NoiseSocket {
35    pub fn new(
36        transport: Arc<dyn Transport>,
37        write_key: NoiseCipher,
38        read_key: NoiseCipher,
39    ) -> Self {
40        let write_key = Arc::new(write_key);
41        let read_key = Arc::new(read_key);
42
43        // Create channel for send jobs. Buffer size of 32 allows multiple
44        // callers to enqueue work without blocking on channel capacity.
45        let (send_job_tx, send_job_rx) = mpsc::channel::<SendJob>(32);
46
47        // Spawn the dedicated sender task
48        let transport_clone = transport.clone();
49        let write_key_clone = write_key.clone();
50        let sender_task_handle = tokio::spawn(Self::sender_task(
51            transport_clone,
52            write_key_clone,
53            send_job_rx,
54        ));
55
56        Self {
57            read_key,
58            read_counter: Arc::new(AtomicU32::new(0)),
59            send_job_tx,
60            sender_task_handle,
61        }
62    }
63
64    /// Dedicated sender task that processes send jobs sequentially.
65    /// This ensures frames are sent in counter order without requiring a mutex.
66    /// The task owns the write counter and processes jobs one at a time.
67    async fn sender_task(
68        transport: Arc<dyn Transport>,
69        write_key: Arc<NoiseCipher>,
70        mut send_job_rx: mpsc::Receiver<SendJob>,
71    ) {
72        let mut write_counter: u32 = 0;
73
74        while let Some(job) = send_job_rx.recv().await {
75            let result = Self::process_send_job(
76                &transport,
77                &write_key,
78                &mut write_counter,
79                job.plaintext_buf,
80                job.out_buf,
81            )
82            .await;
83
84            // Send result back to caller. Ignore error if receiver was dropped.
85            let _ = job.response_tx.send(result);
86        }
87
88        // Channel closed - NoiseSocket was dropped, task exits naturally
89    }
90
91    /// Process a single send job: encrypt and send the message.
92    async fn process_send_job(
93        transport: &Arc<dyn Transport>,
94        write_key: &Arc<NoiseCipher>,
95        write_counter: &mut u32,
96        mut plaintext_buf: Vec<u8>,
97        mut out_buf: Vec<u8>,
98    ) -> SendResult {
99        let counter = *write_counter;
100        *write_counter = write_counter.wrapping_add(1);
101
102        // For small messages, encrypt in-place in out_buf to avoid allocation
103        if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
104            // Copy plaintext to out_buf and encrypt in-place
105            out_buf.clear();
106            out_buf.extend_from_slice(&plaintext_buf);
107            plaintext_buf.clear();
108
109            if let Err(e) = write_key.encrypt_in_place_with_counter(counter, &mut out_buf) {
110                return Err(EncryptSendError::crypto(
111                    anyhow::anyhow!(e.to_string()),
112                    plaintext_buf,
113                    out_buf,
114                ));
115            }
116
117            // Frame the ciphertext - we need a temporary copy since encode_frame_into
118            // clears the output buffer
119            let ciphertext_len = out_buf.len();
120            plaintext_buf.extend_from_slice(&out_buf);
121            out_buf.clear();
122            if let Err(e) = wacore::framing::encode_frame_into(
123                &plaintext_buf[..ciphertext_len],
124                None,
125                &mut out_buf,
126            ) {
127                plaintext_buf.clear();
128                return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
129            }
130            plaintext_buf.clear();
131        } else {
132            // Offload larger messages to a blocking thread
133            let write_key = write_key.clone();
134
135            let plaintext_arc = Arc::new(plaintext_buf);
136            let plaintext_arc_for_task = plaintext_arc.clone();
137
138            let spawn_result = tokio::task::spawn_blocking(move || {
139                write_key.encrypt_with_counter(counter, &plaintext_arc_for_task[..])
140            })
141            .await;
142
143            plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
144
145            let ciphertext = match spawn_result {
146                Ok(Ok(c)) => c,
147                Ok(Err(e)) => {
148                    return Err(EncryptSendError::crypto(
149                        anyhow::anyhow!(e.to_string()),
150                        plaintext_buf,
151                        out_buf,
152                    ));
153                }
154                Err(join_err) => {
155                    return Err(EncryptSendError::join(join_err, plaintext_buf, out_buf));
156                }
157            };
158
159            plaintext_buf.clear();
160            out_buf.clear();
161            if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
162                return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
163            }
164        }
165
166        if let Err(e) = transport.send(out_buf).await {
167            return Err(EncryptSendError::transport(e, plaintext_buf, Vec::new()));
168        }
169
170        Ok((plaintext_buf, Vec::new()))
171    }
172
173    pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
174        let (response_tx, response_rx) = oneshot::channel();
175
176        let job = SendJob {
177            plaintext_buf,
178            out_buf,
179            response_tx,
180        };
181
182        // Send job to the sender task. If channel is closed, sender task has stopped.
183        if let Err(send_err) = self.send_job_tx.send(job).await {
184            // Recover the buffers from the failed send job so caller can reuse them
185            let job = send_err.0;
186            return Err(EncryptSendError::channel_closed(
187                job.plaintext_buf,
188                job.out_buf,
189            ));
190        }
191
192        // Wait for the sender task to process our job and return the result
193        match response_rx.await {
194            Ok(result) => result,
195            Err(_) => {
196                // Sender task dropped without sending a response
197                Err(EncryptSendError::channel_closed(Vec::new(), Vec::new()))
198            }
199        }
200    }
201
202    pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
203        let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
204        self.read_key
205            .decrypt_with_counter(counter, ciphertext)
206            .map_err(|e| SocketError::Crypto(e.to_string()))
207    }
208}
209
210impl Drop for NoiseSocket {
211    fn drop(&mut self) {
212        // Abort the sender task to prevent resource leaks if it's stuck
213        // on a slow/hanging network operation. This ensures cleanup even
214        // if transport.send() never returns.
215        self.sender_task_handle.abort();
216    }
217}
218
219#[cfg(test)]
220mod tests {
221    use super::*;
222
223    #[tokio::test]
224    async fn test_encrypt_and_send_returns_both_buffers() {
225        // Create a mock transport
226        let transport = Arc::new(crate::transport::mock::MockTransport);
227
228        // Create dummy keys for testing
229        let key = [0u8; 32];
230        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
231        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
232
233        let socket = NoiseSocket::new(transport, write_key, read_key);
234
235        let plaintext_buf = Vec::with_capacity(1024);
236        let encrypted_buf = Vec::with_capacity(1024);
237        let plaintext_capacity = plaintext_buf.capacity();
238
239        let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
240        assert!(result.is_ok(), "encrypt_and_send should succeed");
241
242        let (returned_plaintext, returned_encrypted) = result.unwrap();
243
244        assert_eq!(
245            returned_plaintext.capacity(),
246            plaintext_capacity,
247            "Plaintext buffer should maintain its capacity"
248        );
249        assert!(
250            returned_encrypted.is_empty(),
251            "Encrypted buffer is moved to transport (zero-copy)"
252        );
253        assert!(
254            returned_plaintext.is_empty(),
255            "Returned plaintext buffer should be cleared"
256        );
257    }
258
259    #[tokio::test]
260    async fn test_concurrent_sends_maintain_order() {
261        use async_trait::async_trait;
262        use std::sync::Arc;
263        use tokio::sync::Mutex;
264
265        // Create a mock transport that records the order of sends by decrypting
266        // the first byte (which contains the task index)
267        struct RecordingTransport {
268            recorded_order: Arc<Mutex<Vec<u8>>>,
269            read_key: NoiseCipher,
270            counter: std::sync::atomic::AtomicU32,
271        }
272
273        #[async_trait]
274        impl crate::transport::Transport for RecordingTransport {
275            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
276                // Decrypt the data to extract the index (first byte of plaintext)
277                if data.len() > 16 {
278                    // Skip the noise frame header (3 bytes for length)
279                    let ciphertext = &data[3..];
280                    let counter = self
281                        .counter
282                        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
283
284                    if let Ok(plaintext) = self.read_key.decrypt_with_counter(counter, ciphertext)
285                        && !plaintext.is_empty()
286                    {
287                        let index = plaintext[0];
288                        let mut order = self.recorded_order.lock().await;
289                        order.push(index);
290                    }
291                }
292                Ok(())
293            }
294
295            async fn disconnect(&self) {}
296        }
297
298        let recorded_order = Arc::new(Mutex::new(Vec::new()));
299        let key = [0u8; 32];
300        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
301        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
302
303        let transport = Arc::new(RecordingTransport {
304            recorded_order: recorded_order.clone(),
305            read_key: NoiseCipher::new(&key).expect("32-byte key should be valid"),
306            counter: std::sync::atomic::AtomicU32::new(0),
307        });
308
309        let socket = Arc::new(NoiseSocket::new(transport, write_key, read_key));
310
311        // Spawn multiple concurrent sends with their indices
312        let mut handles = Vec::new();
313        for i in 0..10 {
314            let socket = socket.clone();
315            handles.push(tokio::spawn(async move {
316                // Use index as the first byte of plaintext to identify this send
317                let mut plaintext = vec![i as u8];
318                plaintext.extend_from_slice(&[0u8; 99]);
319                let out_buf = Vec::with_capacity(256);
320                socket.encrypt_and_send(plaintext, out_buf).await
321            }));
322        }
323
324        // Wait for all sends to complete
325        for handle in handles {
326            let result = handle.await.expect("task should complete");
327            assert!(result.is_ok(), "All sends should succeed");
328        }
329
330        // Verify all sends completed in FIFO order (0, 1, 2, ..., 9)
331        let order = recorded_order.lock().await;
332        let expected: Vec<u8> = (0..10).collect();
333        assert_eq!(*order, expected, "Sends should maintain FIFO order");
334    }
335
336    /// Tests that the encrypted buffer sizing formula (plaintext.len() + 32) is sufficient.
337    /// This verifies the optimization in client.rs that sizes the buffer based on payload.
338    #[tokio::test]
339    async fn test_encrypted_buffer_sizing_is_sufficient() {
340        use async_trait::async_trait;
341        use std::sync::Arc;
342        use std::sync::atomic::{AtomicUsize, Ordering};
343
344        // Transport that records the actual encrypted data size
345        struct SizeRecordingTransport {
346            last_size: Arc<AtomicUsize>,
347        }
348
349        #[async_trait]
350        impl crate::transport::Transport for SizeRecordingTransport {
351            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
352                self.last_size.store(data.len(), Ordering::SeqCst);
353                Ok(())
354            }
355            async fn disconnect(&self) {}
356        }
357
358        let last_size = Arc::new(AtomicUsize::new(0));
359        let transport = Arc::new(SizeRecordingTransport {
360            last_size: last_size.clone(),
361        });
362
363        let key = [0u8; 32];
364        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
365        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
366
367        let socket = NoiseSocket::new(transport, write_key, read_key);
368
369        // Test various payload sizes: tiny, small, medium, large, very large
370        let test_sizes = [0, 1, 50, 100, 500, 1000, 1024, 2000, 5000, 16384, 20000];
371
372        for size in test_sizes {
373            let plaintext = vec![0xABu8; size];
374            // This is the formula used in client.rs
375            let buffer_capacity = plaintext.len() + 32;
376            let encrypted_buf = Vec::with_capacity(buffer_capacity);
377
378            let result = socket
379                .encrypt_and_send(plaintext.clone(), encrypted_buf)
380                .await;
381
382            assert!(
383                result.is_ok(),
384                "encrypt_and_send should succeed for payload size {}",
385                size
386            );
387
388            let actual_encrypted_size = last_size.load(Ordering::SeqCst);
389
390            // Verify the actual encrypted size fits within our allocated capacity
391            // Encrypted size = plaintext + 16 (AES-GCM tag) + 3 (frame header) = plaintext + 19
392            let expected_max = size + 19;
393            assert_eq!(
394                actual_encrypted_size, expected_max,
395                "Encrypted size for {} byte payload should be {} (got {})",
396                size, expected_max, actual_encrypted_size
397            );
398
399            // Verify our buffer sizing formula provides enough capacity
400            assert!(
401                buffer_capacity >= actual_encrypted_size,
402                "Buffer capacity {} should be >= encrypted size {} for payload size {}",
403                buffer_capacity,
404                actual_encrypted_size,
405                size
406            );
407        }
408    }
409
410    /// Tests edge cases for buffer sizing
411    #[tokio::test]
412    async fn test_encrypted_buffer_sizing_edge_cases() {
413        use async_trait::async_trait;
414        use std::sync::Arc;
415
416        struct NoOpTransport;
417
418        #[async_trait]
419        impl crate::transport::Transport for NoOpTransport {
420            async fn send(&self, _data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
421                Ok(())
422            }
423            async fn disconnect(&self) {}
424        }
425
426        let transport = Arc::new(NoOpTransport);
427        let key = [0u8; 32];
428        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
429        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
430
431        let socket = NoiseSocket::new(transport, write_key, read_key);
432
433        // Test empty payload
434        let result = socket
435            .encrypt_and_send(vec![], Vec::with_capacity(32))
436            .await;
437        assert!(result.is_ok(), "Empty payload should encrypt successfully");
438
439        // Test payload at inline threshold boundary (16KB)
440        let at_threshold = vec![0u8; 16 * 1024];
441        let result = socket
442            .encrypt_and_send(at_threshold, Vec::with_capacity(16 * 1024 + 32))
443            .await;
444        assert!(
445            result.is_ok(),
446            "Payload at inline threshold should encrypt successfully"
447        );
448
449        // Test payload just above inline threshold
450        let above_threshold = vec![0u8; 16 * 1024 + 1];
451        let result = socket
452            .encrypt_and_send(above_threshold, Vec::with_capacity(16 * 1024 + 33))
453            .await;
454        assert!(
455            result.is_ok(),
456            "Payload above inline threshold should encrypt successfully"
457        );
458    }
459}