// whatsapp_rust/socket/noise_socket.rs

1use crate::socket::error::{EncryptSendError, Result, SocketError};
2use crate::transport::Transport;
3use async_channel;
4use futures::channel::oneshot;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU32, Ordering};
7use wacore::handshake::NoiseCipher;
8use wacore::runtime::{AbortHandle, Runtime};
9
/// Largest plaintext size (16 KiB) that is encrypted directly on the sender task;
/// bigger payloads are encrypted on a blocking thread instead.
const INLINE_ENCRYPT_THRESHOLD: usize = 16 << 10;
11
12/// Result type for send operations.
13type SendResult = std::result::Result<(), EncryptSendError>;
14
15/// A job sent to the dedicated sender task.
16struct SendJob {
17    plaintext_buf: Vec<u8>,
18    out_buf: Vec<u8>,
19    response_tx: oneshot::Sender<SendResult>,
20}
21
22pub struct NoiseSocket {
23    #[allow(dead_code)] // Kept for potential future spawns
24    runtime: Arc<dyn Runtime>,
25    read_key: Arc<NoiseCipher>,
26    read_counter: Arc<AtomicU32>,
27    /// Channel to send jobs to the dedicated sender task.
28    /// Using a channel instead of a mutex avoids blocking callers while
29    /// the current send is in progress - they can enqueue their work and
30    /// await the result without holding a lock.
31    send_job_tx: async_channel::Sender<SendJob>,
32    /// Handle to the sender task. Aborted on drop to prevent resource leaks
33    /// if the task is stuck on a slow/hanging network operation.
34    _sender_task_handle: AbortHandle,
35}
36
37impl NoiseSocket {
38    pub fn new(
39        runtime: Arc<dyn Runtime>,
40        transport: Arc<dyn Transport>,
41        write_key: NoiseCipher,
42        read_key: NoiseCipher,
43    ) -> Self {
44        let write_key = Arc::new(write_key);
45        let read_key = Arc::new(read_key);
46
47        // Create channel for send jobs. Buffer size of 32 allows multiple
48        // callers to enqueue work without blocking on channel capacity.
49        let (send_job_tx, send_job_rx) = async_channel::bounded::<SendJob>(32);
50
51        // Spawn the dedicated sender task
52        let transport_clone = transport.clone();
53        let write_key_clone = write_key.clone();
54        let rt_clone = runtime.clone();
55        let sender_task_handle = runtime.spawn(Box::pin(Self::sender_task(
56            rt_clone,
57            transport_clone,
58            write_key_clone,
59            send_job_rx,
60        )));
61
62        Self {
63            runtime,
64            read_key,
65            read_counter: Arc::new(AtomicU32::new(0)),
66            send_job_tx,
67            _sender_task_handle: sender_task_handle,
68        }
69    }
70
71    /// Dedicated sender task that processes send jobs sequentially.
72    /// This ensures frames are sent in counter order without requiring a mutex.
73    /// The task owns the write counter and processes jobs one at a time.
74    async fn sender_task(
75        runtime: Arc<dyn Runtime>,
76        transport: Arc<dyn Transport>,
77        write_key: Arc<NoiseCipher>,
78        send_job_rx: async_channel::Receiver<SendJob>,
79    ) {
80        let mut write_counter: u32 = 0;
81
82        while let Ok(job) = send_job_rx.recv().await {
83            let result = Self::process_send_job(
84                &runtime,
85                &transport,
86                &write_key,
87                &mut write_counter,
88                job.plaintext_buf,
89                job.out_buf,
90            )
91            .await;
92
93            // Send result back to caller. Ignore error if receiver was dropped.
94            let _ = job.response_tx.send(result);
95        }
96
97        // Channel closed - NoiseSocket was dropped, task exits naturally
98    }
99
100    /// Process a single send job: encrypt and send the message.
101    async fn process_send_job(
102        runtime: &Arc<dyn Runtime>,
103        transport: &Arc<dyn Transport>,
104        write_key: &Arc<NoiseCipher>,
105        write_counter: &mut u32,
106        mut plaintext_buf: Vec<u8>,
107        mut out_buf: Vec<u8>,
108    ) -> SendResult {
109        let counter = *write_counter;
110
111        // For small messages, encrypt plaintext_buf in-place then frame into out_buf.
112        // This avoids the previous triple-copy pattern (plaintext→out→plaintext→out).
113        if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
114            if let Err(e) = write_key.encrypt_in_place_with_counter(counter, &mut plaintext_buf) {
115                return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
116            }
117
118            // Frame the ciphertext from plaintext_buf into out_buf (single copy)
119            out_buf.clear();
120            if let Err(e) = wacore::framing::encode_frame_into(&plaintext_buf, None, &mut out_buf) {
121                return Err(EncryptSendError::framing(e));
122            }
123        } else {
124            // Offload larger messages to a blocking thread
125            let write_key = write_key.clone();
126
127            let plaintext_arc = Arc::new(plaintext_buf);
128            let plaintext_arc_for_task = plaintext_arc.clone();
129
130            let encrypt_result = wacore::runtime::blocking(&**runtime, move || {
131                write_key.encrypt_with_counter(counter, &plaintext_arc_for_task[..])
132            })
133            .await;
134
135            // Recover ownership so the buffer is dropped at end of scope
136            plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
137            drop(plaintext_buf);
138
139            let ciphertext = match encrypt_result {
140                Ok(c) => c,
141                Err(e) => {
142                    return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
143                }
144            };
145
146            out_buf.clear();
147            if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
148                return Err(EncryptSendError::framing(e));
149            }
150        }
151
152        if let Err(e) = transport.send(out_buf).await {
153            return Err(EncryptSendError::transport(e));
154        }
155
156        // Only advance the counter after the encrypted frame was successfully sent.
157        // If transport.send() fails, we can retry with the same counter value.
158        *write_counter = write_counter.wrapping_add(1);
159
160        Ok(())
161    }
162
163    pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
164        let (response_tx, response_rx) = oneshot::channel();
165
166        let job = SendJob {
167            plaintext_buf,
168            out_buf,
169            response_tx,
170        };
171
172        // Send job to the sender task. If channel is closed, sender task has stopped.
173        if let Err(_send_err) = self.send_job_tx.send(job).await {
174            return Err(EncryptSendError::channel_closed());
175        }
176
177        // Wait for the sender task to process our job and return the result
178        match response_rx.await {
179            Ok(result) => result,
180            Err(_) => {
181                // Sender task dropped without sending a response
182                Err(EncryptSendError::channel_closed())
183            }
184        }
185    }
186
187    pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
188        let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
189        self.read_key
190            .decrypt_with_counter(counter, ciphertext)
191            .map_err(|e| SocketError::Crypto(e.to_string()))
192    }
193}
194
// `AbortHandle` aborts the sender task when it is dropped, so no manual `Drop`
// impl is needed: dropping the `_sender_task_handle` field does the work.
197
198#[cfg(test)]
199mod tests {
200    use super::*;
201
202    #[tokio::test]
203    async fn test_encrypt_and_send_succeeds() {
204        let transport = Arc::new(crate::transport::mock::MockTransport);
205
206        let key = [0u8; 32];
207        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
208        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
209
210        let socket = NoiseSocket::new(
211            Arc::new(crate::runtime_impl::TokioRuntime),
212            transport,
213            write_key,
214            read_key,
215        );
216
217        let plaintext_buf = Vec::with_capacity(1024);
218        let encrypted_buf = Vec::with_capacity(1024);
219
220        let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
221        assert!(result.is_ok(), "encrypt_and_send should succeed");
222    }
223
224    #[tokio::test]
225    async fn test_concurrent_sends_maintain_order() {
226        use async_lock::Mutex;
227        use async_trait::async_trait;
228        use std::sync::Arc;
229
230        // Create a mock transport that records the order of sends by decrypting
231        // the first byte (which contains the task index)
232        struct RecordingTransport {
233            recorded_order: Arc<Mutex<Vec<u8>>>,
234            read_key: NoiseCipher,
235            counter: std::sync::atomic::AtomicU32,
236        }
237
238        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
239        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
240        impl crate::transport::Transport for RecordingTransport {
241            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
242                // Decrypt the data to extract the index (first byte of plaintext)
243                if data.len() > 16 {
244                    // Skip the noise frame header (3 bytes for length)
245                    let ciphertext = &data[3..];
246                    let counter = self
247                        .counter
248                        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
249
250                    if let Ok(plaintext) = self.read_key.decrypt_with_counter(counter, ciphertext)
251                        && !plaintext.is_empty()
252                    {
253                        let index = plaintext[0];
254                        let mut order = self.recorded_order.lock().await;
255                        order.push(index);
256                    }
257                }
258                Ok(())
259            }
260
261            async fn disconnect(&self) {}
262        }
263
264        let recorded_order = Arc::new(Mutex::new(Vec::new()));
265        let key = [0u8; 32];
266        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
267        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
268
269        let transport = Arc::new(RecordingTransport {
270            recorded_order: recorded_order.clone(),
271            read_key: NoiseCipher::new(&key).expect("32-byte key should be valid"),
272            counter: std::sync::atomic::AtomicU32::new(0),
273        });
274
275        let socket = Arc::new(NoiseSocket::new(
276            Arc::new(crate::runtime_impl::TokioRuntime),
277            transport,
278            write_key,
279            read_key,
280        ));
281
282        // Spawn multiple concurrent sends with their indices
283        let mut handles = Vec::new();
284        for i in 0..10 {
285            let socket = socket.clone();
286            handles.push(tokio::spawn(async move {
287                // Use index as the first byte of plaintext to identify this send
288                let mut plaintext = vec![i as u8];
289                plaintext.extend_from_slice(&[0u8; 99]);
290                let out_buf = Vec::with_capacity(256);
291                socket.encrypt_and_send(plaintext, out_buf).await
292            }));
293        }
294
295        // Wait for all sends to complete
296        for handle in handles {
297            let result = handle.await.expect("task should complete");
298            assert!(result.is_ok(), "All sends should succeed");
299        }
300
301        // Verify all sends completed in FIFO order (0, 1, 2, ..., 9)
302        let order = recorded_order.lock().await;
303        let expected: Vec<u8> = (0..10).collect();
304        assert_eq!(*order, expected, "Sends should maintain FIFO order");
305    }
306
307    /// Tests that the encrypted buffer sizing formula (plaintext.len() + 32) is sufficient.
308    /// This verifies the optimization in client.rs that sizes the buffer based on payload.
309    #[tokio::test]
310    async fn test_encrypted_buffer_sizing_is_sufficient() {
311        use async_trait::async_trait;
312        use std::sync::Arc;
313        use std::sync::atomic::{AtomicUsize, Ordering};
314
315        // Transport that records the actual encrypted data size
316        struct SizeRecordingTransport {
317            last_size: Arc<AtomicUsize>,
318        }
319
320        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
321        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
322        impl crate::transport::Transport for SizeRecordingTransport {
323            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
324                self.last_size.store(data.len(), Ordering::SeqCst);
325                Ok(())
326            }
327            async fn disconnect(&self) {}
328        }
329
330        let last_size = Arc::new(AtomicUsize::new(0));
331        let transport = Arc::new(SizeRecordingTransport {
332            last_size: last_size.clone(),
333        });
334
335        let key = [0u8; 32];
336        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
337        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
338
339        let socket = NoiseSocket::new(
340            Arc::new(crate::runtime_impl::TokioRuntime),
341            transport,
342            write_key,
343            read_key,
344        );
345
346        // Test various payload sizes: tiny, small, medium, large, very large
347        let test_sizes = [0, 1, 50, 100, 500, 1000, 1024, 2000, 5000, 16384, 20000];
348
349        for size in test_sizes {
350            let plaintext = vec![0xABu8; size];
351            // This is the formula used in client.rs
352            let buffer_capacity = plaintext.len() + 32;
353            let encrypted_buf = Vec::with_capacity(buffer_capacity);
354
355            let result = socket
356                .encrypt_and_send(plaintext.clone(), encrypted_buf)
357                .await;
358
359            assert!(
360                result.is_ok(),
361                "encrypt_and_send should succeed for payload size {}",
362                size
363            );
364
365            let actual_encrypted_size = last_size.load(Ordering::SeqCst);
366
367            // Verify the actual encrypted size fits within our allocated capacity
368            // Encrypted size = plaintext + 16 (AES-GCM tag) + 3 (frame header) = plaintext + 19
369            let expected_max = size + 19;
370            assert_eq!(
371                actual_encrypted_size, expected_max,
372                "Encrypted size for {} byte payload should be {} (got {})",
373                size, expected_max, actual_encrypted_size
374            );
375
376            // Verify our buffer sizing formula provides enough capacity
377            assert!(
378                buffer_capacity >= actual_encrypted_size,
379                "Buffer capacity {} should be >= encrypted size {} for payload size {}",
380                buffer_capacity,
381                actual_encrypted_size,
382                size
383            );
384        }
385    }
386
387    /// Tests edge cases for buffer sizing
388    #[tokio::test]
389    async fn test_encrypted_buffer_sizing_edge_cases() {
390        use async_trait::async_trait;
391        use std::sync::Arc;
392
393        struct NoOpTransport;
394
395        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
396        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
397        impl crate::transport::Transport for NoOpTransport {
398            async fn send(&self, _data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
399                Ok(())
400            }
401            async fn disconnect(&self) {}
402        }
403
404        let transport = Arc::new(NoOpTransport);
405        let key = [0u8; 32];
406        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
407        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
408
409        let socket = NoiseSocket::new(
410            Arc::new(crate::runtime_impl::TokioRuntime),
411            transport,
412            write_key,
413            read_key,
414        );
415
416        // Test empty payload
417        let result = socket
418            .encrypt_and_send(vec![], Vec::with_capacity(32))
419            .await;
420        assert!(result.is_ok(), "Empty payload should encrypt successfully");
421
422        // Test payload at inline threshold boundary (16KB)
423        let at_threshold = vec![0u8; 16 * 1024];
424        let result = socket
425            .encrypt_and_send(at_threshold, Vec::with_capacity(16 * 1024 + 32))
426            .await;
427        assert!(
428            result.is_ok(),
429            "Payload at inline threshold should encrypt successfully"
430        );
431
432        // Test payload just above inline threshold
433        let above_threshold = vec![0u8; 16 * 1024 + 1];
434        let result = socket
435            .encrypt_and_send(above_threshold, Vec::with_capacity(16 * 1024 + 33))
436            .await;
437        assert!(
438            result.is_ok(),
439            "Payload above inline threshold should encrypt successfully"
440        );
441    }
442}