// whatsapp_rust/socket/noise_socket.rs

1use crate::socket::error::{EncryptSendError, Result, SocketError};
2use crate::transport::Transport;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, Ordering};
5use tokio::sync::{mpsc, oneshot};
6use tokio::task::JoinHandle;
7use wacore::aes_gcm::{
8    Aes256Gcm,
9    aead::{Aead, AeadInPlace},
10};
11use wacore::handshake::utils::generate_iv;
12
/// Plaintexts of at most this many bytes (16 KiB) are encrypted inline on the
/// sender task; anything larger is handed to `spawn_blocking`.
const INLINE_ENCRYPT_THRESHOLD: usize = 16_384;
14
15/// Result type for send operations, returning both buffers for reuse.
16type SendResult = std::result::Result<(Vec<u8>, Vec<u8>), EncryptSendError>;
17
18/// A job sent to the dedicated sender task.
19struct SendJob {
20    plaintext_buf: Vec<u8>,
21    out_buf: Vec<u8>,
22    response_tx: oneshot::Sender<SendResult>,
23}
24
25pub struct NoiseSocket {
26    read_key: Arc<Aes256Gcm>,
27    read_counter: Arc<AtomicU32>,
28    /// Channel to send jobs to the dedicated sender task.
29    /// Using a channel instead of a mutex avoids blocking callers while
30    /// the current send is in progress - they can enqueue their work and
31    /// await the result without holding a lock.
32    send_job_tx: mpsc::Sender<SendJob>,
33    /// Handle to the sender task. Aborted on drop to prevent resource leaks
34    /// if the task is stuck on a slow/hanging network operation.
35    sender_task_handle: JoinHandle<()>,
36}
37
38impl NoiseSocket {
39    pub fn new(transport: Arc<dyn Transport>, write_key: Aes256Gcm, read_key: Aes256Gcm) -> Self {
40        let write_key = Arc::new(write_key);
41        let read_key = Arc::new(read_key);
42
43        // Create channel for send jobs. Buffer size of 32 allows multiple
44        // callers to enqueue work without blocking on channel capacity.
45        let (send_job_tx, send_job_rx) = mpsc::channel::<SendJob>(32);
46
47        // Spawn the dedicated sender task
48        let transport_clone = transport.clone();
49        let write_key_clone = write_key.clone();
50        let sender_task_handle = tokio::spawn(Self::sender_task(
51            transport_clone,
52            write_key_clone,
53            send_job_rx,
54        ));
55
56        Self {
57            read_key,
58            read_counter: Arc::new(AtomicU32::new(0)),
59            send_job_tx,
60            sender_task_handle,
61        }
62    }
63
64    /// Dedicated sender task that processes send jobs sequentially.
65    /// This ensures frames are sent in counter order without requiring a mutex.
66    /// The task owns the write counter and processes jobs one at a time.
67    async fn sender_task(
68        transport: Arc<dyn Transport>,
69        write_key: Arc<Aes256Gcm>,
70        mut send_job_rx: mpsc::Receiver<SendJob>,
71    ) {
72        let mut write_counter: u32 = 0;
73
74        while let Some(job) = send_job_rx.recv().await {
75            let result = Self::process_send_job(
76                &transport,
77                &write_key,
78                &mut write_counter,
79                job.plaintext_buf,
80                job.out_buf,
81            )
82            .await;
83
84            // Send result back to caller. Ignore error if receiver was dropped.
85            let _ = job.response_tx.send(result);
86        }
87
88        // Channel closed - NoiseSocket was dropped, task exits naturally
89    }
90
91    /// Process a single send job: encrypt and send the message.
92    async fn process_send_job(
93        transport: &Arc<dyn Transport>,
94        write_key: &Arc<Aes256Gcm>,
95        write_counter: &mut u32,
96        mut plaintext_buf: Vec<u8>,
97        mut out_buf: Vec<u8>,
98    ) -> SendResult {
99        let counter = *write_counter;
100        *write_counter = write_counter.wrapping_add(1);
101
102        // For small messages, encrypt in-place in out_buf to avoid allocation
103        if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
104            // Copy plaintext to out_buf and encrypt in-place
105            out_buf.clear();
106            out_buf.extend_from_slice(&plaintext_buf);
107            plaintext_buf.clear();
108
109            let iv = generate_iv(counter);
110            if let Err(e) = write_key.encrypt_in_place(iv.as_ref().into(), b"", &mut out_buf) {
111                return Err(EncryptSendError::crypto(
112                    anyhow::anyhow!(e.to_string()),
113                    plaintext_buf,
114                    out_buf,
115                ));
116            }
117
118            // Frame the ciphertext - we need a temporary copy since encode_frame_into
119            // clears the output buffer
120            let ciphertext_len = out_buf.len();
121            plaintext_buf.extend_from_slice(&out_buf);
122            out_buf.clear();
123            if let Err(e) = wacore::framing::encode_frame_into(
124                &plaintext_buf[..ciphertext_len],
125                None,
126                &mut out_buf,
127            ) {
128                plaintext_buf.clear();
129                return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
130            }
131            plaintext_buf.clear();
132        } else {
133            // Offload larger messages to a blocking thread
134            let write_key = write_key.clone();
135
136            let plaintext_arc = Arc::new(plaintext_buf);
137            let plaintext_arc_for_task = plaintext_arc.clone();
138
139            let spawn_result = tokio::task::spawn_blocking(move || {
140                let iv = generate_iv(counter);
141                write_key.encrypt(iv.as_ref().into(), &plaintext_arc_for_task[..])
142            })
143            .await;
144
145            plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
146
147            let ciphertext = match spawn_result {
148                Ok(Ok(c)) => c,
149                Ok(Err(e)) => {
150                    return Err(EncryptSendError::crypto(
151                        anyhow::anyhow!(e.to_string()),
152                        plaintext_buf,
153                        out_buf,
154                    ));
155                }
156                Err(join_err) => {
157                    return Err(EncryptSendError::join(join_err, plaintext_buf, out_buf));
158                }
159            };
160
161            plaintext_buf.clear();
162            out_buf.clear();
163            if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
164                return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
165            }
166        }
167
168        if let Err(e) = transport.send(&out_buf).await {
169            return Err(EncryptSendError::transport(e, plaintext_buf, out_buf));
170        }
171
172        out_buf.clear();
173        Ok((plaintext_buf, out_buf))
174    }
175
176    pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
177        let (response_tx, response_rx) = oneshot::channel();
178
179        let job = SendJob {
180            plaintext_buf,
181            out_buf,
182            response_tx,
183        };
184
185        // Send job to the sender task. If channel is closed, sender task has stopped.
186        if let Err(send_err) = self.send_job_tx.send(job).await {
187            // Recover the buffers from the failed send job so caller can reuse them
188            let job = send_err.0;
189            return Err(EncryptSendError::channel_closed(
190                job.plaintext_buf,
191                job.out_buf,
192            ));
193        }
194
195        // Wait for the sender task to process our job and return the result
196        match response_rx.await {
197            Ok(result) => result,
198            Err(_) => {
199                // Sender task dropped without sending a response
200                Err(EncryptSendError::channel_closed(Vec::new(), Vec::new()))
201            }
202        }
203    }
204
205    pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
206        let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
207        let iv = generate_iv(counter);
208        self.read_key
209            .decrypt(iv.as_ref().into(), ciphertext)
210            .map_err(|e| SocketError::Crypto(e.to_string()))
211    }
212}
213
214impl Drop for NoiseSocket {
215    fn drop(&mut self) {
216        // Abort the sender task to prevent resource leaks if it's stuck
217        // on a slow/hanging network operation. This ensures cleanup even
218        // if transport.send() never returns.
219        self.sender_task_handle.abort();
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use wacore::aes_gcm::{Aes256Gcm, KeyInit};

    /// Builds an AES-256-GCM cipher from an all-zero key. The tests use the
    /// same key everywhere so sent frames can be decrypted locally.
    fn zero_key_cipher() -> Aes256Gcm {
        let key = [0u8; 32];
        Aes256Gcm::new_from_slice(&key).expect("32-byte key should be valid for AES-256-GCM")
    }

    #[tokio::test]
    async fn test_encrypt_and_send_returns_both_buffers() {
        let mock_transport = Arc::new(crate::transport::mock::MockTransport);
        let socket = NoiseSocket::new(mock_transport, zero_key_cipher(), zero_key_cipher());

        // Empty buffers with a known, non-zero capacity.
        let plaintext_in: Vec<u8> = Vec::with_capacity(1024);
        let ciphertext_in: Vec<u8> = Vec::with_capacity(1024);
        let (plaintext_cap, ciphertext_cap) = (plaintext_in.capacity(), ciphertext_in.capacity());

        let outcome = socket.encrypt_and_send(plaintext_in, ciphertext_in).await;
        assert!(outcome.is_ok(), "encrypt_and_send should succeed");
        let (plaintext_out, ciphertext_out) =
            outcome.expect("encrypt_and_send result should unwrap after is_ok check");

        // The same allocations must come back...
        assert_eq!(
            plaintext_out.capacity(),
            plaintext_cap,
            "Plaintext buffer should maintain its capacity"
        );
        assert_eq!(
            ciphertext_out.capacity(),
            ciphertext_cap,
            "Encrypted buffer should maintain its capacity"
        );

        // ...and they must be empty.
        assert!(
            plaintext_out.is_empty(),
            "Returned plaintext buffer should be cleared"
        );
        assert!(
            ciphertext_out.is_empty(),
            "Returned encrypted buffer should be cleared"
        );
    }

    #[tokio::test]
    async fn test_concurrent_sends_maintain_order() {
        use async_trait::async_trait;
        use tokio::sync::Mutex;

        /// Transport that decrypts each frame it receives (tracking its own
        /// counter) and records the first plaintext byte, i.e. the index the
        /// sending task tagged its payload with.
        struct RecordingTransport {
            recorded_order: Arc<Mutex<Vec<u8>>>,
            read_key: Aes256Gcm,
            counter: AtomicU32,
        }

        #[async_trait]
        impl crate::transport::Transport for RecordingTransport {
            async fn send(&self, data: &[u8]) -> std::result::Result<(), anyhow::Error> {
                // Frames of 16 bytes or fewer are ignored; the test frames
                // are far larger.
                if data.len() <= 16 {
                    return Ok(());
                }

                // Drop the 3-byte frame length prefix to reach the ciphertext.
                let ciphertext = &data[3..];
                let counter = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let iv = generate_iv(counter);

                let first_byte = self
                    .read_key
                    .decrypt(iv.as_ref().into(), ciphertext)
                    .ok()
                    .and_then(|plaintext| plaintext.first().copied());
                if let Some(index) = first_byte {
                    self.recorded_order.lock().await.push(index);
                }
                Ok(())
            }

            async fn disconnect(&self) {}
        }

        let recorded_order = Arc::new(Mutex::new(Vec::new()));
        let transport = Arc::new(RecordingTransport {
            recorded_order: Arc::clone(&recorded_order),
            read_key: zero_key_cipher(),
            counter: AtomicU32::new(0),
        });
        let socket = Arc::new(NoiseSocket::new(
            transport,
            zero_key_cipher(),
            zero_key_cipher(),
        ));

        // Spawn the senders in index order; each payload starts with its index.
        let handles: Vec<_> = (0u8..10)
            .map(|index| {
                let socket = Arc::clone(&socket);
                tokio::spawn(async move {
                    let mut payload = vec![0u8; 100];
                    payload[0] = index;
                    socket
                        .encrypt_and_send(payload, Vec::with_capacity(256))
                        .await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await.expect("task should complete");
            assert!(result.is_ok(), "All sends should succeed");
        }

        // Frames must have hit the transport in FIFO order: 0, 1, ..., 9.
        let order = recorded_order.lock().await;
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(*order, expected, "Sends should maintain FIFO order");
    }
}