1use crate::socket::error::{EncryptSendError, Result, SocketError};
2use crate::transport::Transport;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, Ordering};
5use tokio::sync::{mpsc, oneshot};
6use tokio::task::JoinHandle;
7use wacore::handshake::NoiseCipher;
8
/// Largest plaintext size, in bytes, that is still encrypted directly on the
/// sender task. Payloads above this size are encrypted through
/// `tokio::task::spawn_blocking` instead.
const INLINE_ENCRYPT_THRESHOLD: usize = 16 * 1024;

/// Outcome of one send operation.
///
/// On success the tuple is `(plaintext_buf, out_buf)`, handed back so the
/// caller can reuse the allocations. On failure the `EncryptSendError` is
/// constructed with whatever buffers could still be recovered.
type SendResult = std::result::Result<(Vec<u8>, Vec<u8>), EncryptSendError>;
13
/// One unit of work queued for the dedicated sender task.
struct SendJob {
    /// Plaintext payload to encrypt and send.
    plaintext_buf: Vec<u8>,
    /// Caller-provided scratch buffer that receives the framed ciphertext.
    out_buf: Vec<u8>,
    /// One-shot channel on which the sender task reports the result.
    response_tx: oneshot::Sender<SendResult>,
}
20
/// Encrypted socket layered on top of a `Transport`.
///
/// Outgoing frames are funnelled through a single background task, which
/// owns the write key and write counter, so frames are encrypted and sent in
/// the order they were queued. Incoming frames are decrypted synchronously
/// via `decrypt_frame`.
pub struct NoiseSocket {
    /// Cipher used to decrypt incoming frames.
    read_key: Arc<NoiseCipher>,
    /// Counter passed to the cipher for incoming frames; advanced once per
    /// `decrypt_frame` call.
    read_counter: Arc<AtomicU32>,
    /// Queue feeding jobs to the dedicated sender task.
    send_job_tx: mpsc::Sender<SendJob>,
    /// Handle of the sender task; aborted when the socket is dropped.
    sender_task_handle: JoinHandle<()>,
}
33
34impl NoiseSocket {
35 pub fn new(
36 transport: Arc<dyn Transport>,
37 write_key: NoiseCipher,
38 read_key: NoiseCipher,
39 ) -> Self {
40 let write_key = Arc::new(write_key);
41 let read_key = Arc::new(read_key);
42
43 let (send_job_tx, send_job_rx) = mpsc::channel::<SendJob>(32);
46
47 let transport_clone = transport.clone();
49 let write_key_clone = write_key.clone();
50 let sender_task_handle = tokio::spawn(Self::sender_task(
51 transport_clone,
52 write_key_clone,
53 send_job_rx,
54 ));
55
56 Self {
57 read_key,
58 read_counter: Arc::new(AtomicU32::new(0)),
59 send_job_tx,
60 sender_task_handle,
61 }
62 }
63
64 async fn sender_task(
68 transport: Arc<dyn Transport>,
69 write_key: Arc<NoiseCipher>,
70 mut send_job_rx: mpsc::Receiver<SendJob>,
71 ) {
72 let mut write_counter: u32 = 0;
73
74 while let Some(job) = send_job_rx.recv().await {
75 let result = Self::process_send_job(
76 &transport,
77 &write_key,
78 &mut write_counter,
79 job.plaintext_buf,
80 job.out_buf,
81 )
82 .await;
83
84 let _ = job.response_tx.send(result);
86 }
87
88 }
90
91 async fn process_send_job(
93 transport: &Arc<dyn Transport>,
94 write_key: &Arc<NoiseCipher>,
95 write_counter: &mut u32,
96 mut plaintext_buf: Vec<u8>,
97 mut out_buf: Vec<u8>,
98 ) -> SendResult {
99 let counter = *write_counter;
100 *write_counter = write_counter.wrapping_add(1);
101
102 if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
104 out_buf.clear();
106 out_buf.extend_from_slice(&plaintext_buf);
107 plaintext_buf.clear();
108
109 if let Err(e) = write_key.encrypt_in_place_with_counter(counter, &mut out_buf) {
110 return Err(EncryptSendError::crypto(
111 anyhow::anyhow!(e.to_string()),
112 plaintext_buf,
113 out_buf,
114 ));
115 }
116
117 let ciphertext_len = out_buf.len();
120 plaintext_buf.extend_from_slice(&out_buf);
121 out_buf.clear();
122 if let Err(e) = wacore::framing::encode_frame_into(
123 &plaintext_buf[..ciphertext_len],
124 None,
125 &mut out_buf,
126 ) {
127 plaintext_buf.clear();
128 return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
129 }
130 plaintext_buf.clear();
131 } else {
132 let write_key = write_key.clone();
134
135 let plaintext_arc = Arc::new(plaintext_buf);
136 let plaintext_arc_for_task = plaintext_arc.clone();
137
138 let spawn_result = tokio::task::spawn_blocking(move || {
139 write_key.encrypt_with_counter(counter, &plaintext_arc_for_task[..])
140 })
141 .await;
142
143 plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
144
145 let ciphertext = match spawn_result {
146 Ok(Ok(c)) => c,
147 Ok(Err(e)) => {
148 return Err(EncryptSendError::crypto(
149 anyhow::anyhow!(e.to_string()),
150 plaintext_buf,
151 out_buf,
152 ));
153 }
154 Err(join_err) => {
155 return Err(EncryptSendError::join(join_err, plaintext_buf, out_buf));
156 }
157 };
158
159 plaintext_buf.clear();
160 out_buf.clear();
161 if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
162 return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
163 }
164 }
165
166 if let Err(e) = transport.send(out_buf).await {
167 return Err(EncryptSendError::transport(e, plaintext_buf, Vec::new()));
168 }
169
170 Ok((plaintext_buf, Vec::new()))
171 }
172
173 pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
174 let (response_tx, response_rx) = oneshot::channel();
175
176 let job = SendJob {
177 plaintext_buf,
178 out_buf,
179 response_tx,
180 };
181
182 if let Err(send_err) = self.send_job_tx.send(job).await {
184 let job = send_err.0;
186 return Err(EncryptSendError::channel_closed(
187 job.plaintext_buf,
188 job.out_buf,
189 ));
190 }
191
192 match response_rx.await {
194 Ok(result) => result,
195 Err(_) => {
196 Err(EncryptSendError::channel_closed(Vec::new(), Vec::new()))
198 }
199 }
200 }
201
202 pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
203 let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
204 self.read_key
205 .decrypt_with_counter(counter, ciphertext)
206 .map_err(|e| SocketError::Crypto(e.to_string()))
207 }
208}
209
impl Drop for NoiseSocket {
    /// Aborts the sender task so it does not outlive the socket, for example
    /// while it is still awaiting a transport send.
    fn drop(&mut self) {
        self.sender_task_handle.abort();
    }
}
218
219#[cfg(test)]
220mod tests {
221 use super::*;
222
223 #[tokio::test]
224 async fn test_encrypt_and_send_returns_both_buffers() {
225 let transport = Arc::new(crate::transport::mock::MockTransport);
227
228 let key = [0u8; 32];
230 let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
231 let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
232
233 let socket = NoiseSocket::new(transport, write_key, read_key);
234
235 let plaintext_buf = Vec::with_capacity(1024);
236 let encrypted_buf = Vec::with_capacity(1024);
237 let plaintext_capacity = plaintext_buf.capacity();
238
239 let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
240 assert!(result.is_ok(), "encrypt_and_send should succeed");
241
242 let (returned_plaintext, returned_encrypted) = result.unwrap();
243
244 assert_eq!(
245 returned_plaintext.capacity(),
246 plaintext_capacity,
247 "Plaintext buffer should maintain its capacity"
248 );
249 assert!(
250 returned_encrypted.is_empty(),
251 "Encrypted buffer is moved to transport (zero-copy)"
252 );
253 assert!(
254 returned_plaintext.is_empty(),
255 "Returned plaintext buffer should be cleared"
256 );
257 }
258
259 #[tokio::test]
260 async fn test_concurrent_sends_maintain_order() {
261 use async_trait::async_trait;
262 use std::sync::Arc;
263 use tokio::sync::Mutex;
264
265 struct RecordingTransport {
268 recorded_order: Arc<Mutex<Vec<u8>>>,
269 read_key: NoiseCipher,
270 counter: std::sync::atomic::AtomicU32,
271 }
272
273 #[async_trait]
274 impl crate::transport::Transport for RecordingTransport {
275 async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
276 if data.len() > 16 {
278 let ciphertext = &data[3..];
280 let counter = self
281 .counter
282 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
283
284 if let Ok(plaintext) = self.read_key.decrypt_with_counter(counter, ciphertext)
285 && !plaintext.is_empty()
286 {
287 let index = plaintext[0];
288 let mut order = self.recorded_order.lock().await;
289 order.push(index);
290 }
291 }
292 Ok(())
293 }
294
295 async fn disconnect(&self) {}
296 }
297
298 let recorded_order = Arc::new(Mutex::new(Vec::new()));
299 let key = [0u8; 32];
300 let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
301 let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
302
303 let transport = Arc::new(RecordingTransport {
304 recorded_order: recorded_order.clone(),
305 read_key: NoiseCipher::new(&key).expect("32-byte key should be valid"),
306 counter: std::sync::atomic::AtomicU32::new(0),
307 });
308
309 let socket = Arc::new(NoiseSocket::new(transport, write_key, read_key));
310
311 let mut handles = Vec::new();
313 for i in 0..10 {
314 let socket = socket.clone();
315 handles.push(tokio::spawn(async move {
316 let mut plaintext = vec![i as u8];
318 plaintext.extend_from_slice(&[0u8; 99]);
319 let out_buf = Vec::with_capacity(256);
320 socket.encrypt_and_send(plaintext, out_buf).await
321 }));
322 }
323
324 for handle in handles {
326 let result = handle.await.expect("task should complete");
327 assert!(result.is_ok(), "All sends should succeed");
328 }
329
330 let order = recorded_order.lock().await;
332 let expected: Vec<u8> = (0..10).collect();
333 assert_eq!(*order, expected, "Sends should maintain FIFO order");
334 }
335
336 #[tokio::test]
339 async fn test_encrypted_buffer_sizing_is_sufficient() {
340 use async_trait::async_trait;
341 use std::sync::Arc;
342 use std::sync::atomic::{AtomicUsize, Ordering};
343
344 struct SizeRecordingTransport {
346 last_size: Arc<AtomicUsize>,
347 }
348
349 #[async_trait]
350 impl crate::transport::Transport for SizeRecordingTransport {
351 async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
352 self.last_size.store(data.len(), Ordering::SeqCst);
353 Ok(())
354 }
355 async fn disconnect(&self) {}
356 }
357
358 let last_size = Arc::new(AtomicUsize::new(0));
359 let transport = Arc::new(SizeRecordingTransport {
360 last_size: last_size.clone(),
361 });
362
363 let key = [0u8; 32];
364 let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
365 let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
366
367 let socket = NoiseSocket::new(transport, write_key, read_key);
368
369 let test_sizes = [0, 1, 50, 100, 500, 1000, 1024, 2000, 5000, 16384, 20000];
371
372 for size in test_sizes {
373 let plaintext = vec![0xABu8; size];
374 let buffer_capacity = plaintext.len() + 32;
376 let encrypted_buf = Vec::with_capacity(buffer_capacity);
377
378 let result = socket
379 .encrypt_and_send(plaintext.clone(), encrypted_buf)
380 .await;
381
382 assert!(
383 result.is_ok(),
384 "encrypt_and_send should succeed for payload size {}",
385 size
386 );
387
388 let actual_encrypted_size = last_size.load(Ordering::SeqCst);
389
390 let expected_max = size + 19;
393 assert_eq!(
394 actual_encrypted_size, expected_max,
395 "Encrypted size for {} byte payload should be {} (got {})",
396 size, expected_max, actual_encrypted_size
397 );
398
399 assert!(
401 buffer_capacity >= actual_encrypted_size,
402 "Buffer capacity {} should be >= encrypted size {} for payload size {}",
403 buffer_capacity,
404 actual_encrypted_size,
405 size
406 );
407 }
408 }
409
410 #[tokio::test]
412 async fn test_encrypted_buffer_sizing_edge_cases() {
413 use async_trait::async_trait;
414 use std::sync::Arc;
415
416 struct NoOpTransport;
417
418 #[async_trait]
419 impl crate::transport::Transport for NoOpTransport {
420 async fn send(&self, _data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
421 Ok(())
422 }
423 async fn disconnect(&self) {}
424 }
425
426 let transport = Arc::new(NoOpTransport);
427 let key = [0u8; 32];
428 let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
429 let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
430
431 let socket = NoiseSocket::new(transport, write_key, read_key);
432
433 let result = socket
435 .encrypt_and_send(vec![], Vec::with_capacity(32))
436 .await;
437 assert!(result.is_ok(), "Empty payload should encrypt successfully");
438
439 let at_threshold = vec![0u8; 16 * 1024];
441 let result = socket
442 .encrypt_and_send(at_threshold, Vec::with_capacity(16 * 1024 + 32))
443 .await;
444 assert!(
445 result.is_ok(),
446 "Payload at inline threshold should encrypt successfully"
447 );
448
449 let above_threshold = vec![0u8; 16 * 1024 + 1];
451 let result = socket
452 .encrypt_and_send(above_threshold, Vec::with_capacity(16 * 1024 + 33))
453 .await;
454 assert!(
455 result.is_ok(),
456 "Payload above inline threshold should encrypt successfully"
457 );
458 }
459}