1use crate::socket::error::{EncryptSendError, Result, SocketError};
2use crate::transport::Transport;
3use async_channel;
4use futures::channel::oneshot;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU32, Ordering};
7use wacore::handshake::NoiseCipher;
8use wacore::runtime::{AbortHandle, Runtime};
9
/// Largest plaintext (in bytes) that the sender task encrypts in place on its own task;
/// anything bigger is encrypted through `wacore::runtime::blocking` instead.
const INLINE_ENCRYPT_THRESHOLD: usize = 16_384;
11
12type SendResult = std::result::Result<(), EncryptSendError>;
14
15struct SendJob {
17 plaintext_buf: Vec<u8>,
18 out_buf: Vec<u8>,
19 response_tx: oneshot::Sender<SendResult>,
20}
21
22pub struct NoiseSocket {
23 #[allow(dead_code)] runtime: Arc<dyn Runtime>,
25 read_key: Arc<NoiseCipher>,
26 read_counter: Arc<AtomicU32>,
27 send_job_tx: async_channel::Sender<SendJob>,
32 _sender_task_handle: AbortHandle,
35}
36
37impl NoiseSocket {
38 pub fn new(
39 runtime: Arc<dyn Runtime>,
40 transport: Arc<dyn Transport>,
41 write_key: NoiseCipher,
42 read_key: NoiseCipher,
43 ) -> Self {
44 let write_key = Arc::new(write_key);
45 let read_key = Arc::new(read_key);
46
47 let (send_job_tx, send_job_rx) = async_channel::bounded::<SendJob>(32);
50
51 let transport_clone = transport.clone();
53 let write_key_clone = write_key.clone();
54 let rt_clone = runtime.clone();
55 let sender_task_handle = runtime.spawn(Box::pin(Self::sender_task(
56 rt_clone,
57 transport_clone,
58 write_key_clone,
59 send_job_rx,
60 )));
61
62 Self {
63 runtime,
64 read_key,
65 read_counter: Arc::new(AtomicU32::new(0)),
66 send_job_tx,
67 _sender_task_handle: sender_task_handle,
68 }
69 }
70
71 async fn sender_task(
75 runtime: Arc<dyn Runtime>,
76 transport: Arc<dyn Transport>,
77 write_key: Arc<NoiseCipher>,
78 send_job_rx: async_channel::Receiver<SendJob>,
79 ) {
80 let mut write_counter: u32 = 0;
81
82 while let Ok(job) = send_job_rx.recv().await {
83 let result = Self::process_send_job(
84 &runtime,
85 &transport,
86 &write_key,
87 &mut write_counter,
88 job.plaintext_buf,
89 job.out_buf,
90 )
91 .await;
92
93 let _ = job.response_tx.send(result);
95 }
96
97 }
99
100 async fn process_send_job(
102 runtime: &Arc<dyn Runtime>,
103 transport: &Arc<dyn Transport>,
104 write_key: &Arc<NoiseCipher>,
105 write_counter: &mut u32,
106 mut plaintext_buf: Vec<u8>,
107 mut out_buf: Vec<u8>,
108 ) -> SendResult {
109 let counter = *write_counter;
110
111 if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
114 if let Err(e) = write_key.encrypt_in_place_with_counter(counter, &mut plaintext_buf) {
115 return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
116 }
117
118 out_buf.clear();
120 if let Err(e) = wacore::framing::encode_frame_into(&plaintext_buf, None, &mut out_buf) {
121 return Err(EncryptSendError::framing(e));
122 }
123 } else {
124 let write_key = write_key.clone();
126
127 let plaintext_arc = Arc::new(plaintext_buf);
128 let plaintext_arc_for_task = plaintext_arc.clone();
129
130 let encrypt_result = wacore::runtime::blocking(&**runtime, move || {
131 write_key.encrypt_with_counter(counter, &plaintext_arc_for_task[..])
132 })
133 .await;
134
135 plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
137 drop(plaintext_buf);
138
139 let ciphertext = match encrypt_result {
140 Ok(c) => c,
141 Err(e) => {
142 return Err(EncryptSendError::crypto(anyhow::anyhow!(e.to_string())));
143 }
144 };
145
146 out_buf.clear();
147 if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
148 return Err(EncryptSendError::framing(e));
149 }
150 }
151
152 if let Err(e) = transport.send(out_buf).await {
153 return Err(EncryptSendError::transport(e));
154 }
155
156 *write_counter = write_counter.wrapping_add(1);
159
160 Ok(())
161 }
162
163 pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
164 let (response_tx, response_rx) = oneshot::channel();
165
166 let job = SendJob {
167 plaintext_buf,
168 out_buf,
169 response_tx,
170 };
171
172 if let Err(_send_err) = self.send_job_tx.send(job).await {
174 return Err(EncryptSendError::channel_closed());
175 }
176
177 match response_rx.await {
179 Ok(result) => result,
180 Err(_) => {
181 Err(EncryptSendError::channel_closed())
183 }
184 }
185 }
186
187 pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
188 let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
189 self.read_key
190 .decrypt_with_counter(counter, ciphertext)
191 .map_err(|e| SocketError::Crypto(e.to_string()))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: an empty plaintext (the `Vec`s only carry capacity) is encrypted and
    /// accepted by the mock transport.
    #[tokio::test]
    async fn test_encrypt_and_send_succeeds() {
        let transport = Arc::new(crate::transport::mock::MockTransport);

        let key = [0u8; 32];
        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");

        let socket = NoiseSocket::new(
            Arc::new(crate::runtime_impl::TokioRuntime),
            transport,
            write_key,
            read_key,
        );

        let plaintext_buf = Vec::with_capacity(1024);
        let encrypted_buf = Vec::with_capacity(1024);

        let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
        assert!(result.is_ok(), "encrypt_and_send should succeed");
    }

    /// Checks that frames reach the transport in submission order with consecutive
    /// counters: the transport decrypts each frame with its own running counter and
    /// records the first plaintext byte (the submitting task's index).
    ///
    /// NOTE(review): the expected 0..10 order also relies on the spawned tasks enqueueing
    /// their jobs in spawn order, which holds for `#[tokio::test]`'s default current-thread
    /// runtime — revisit if the runtime flavor changes.
    #[tokio::test]
    async fn test_concurrent_sends_maintain_order() {
        use async_lock::Mutex;
        use async_trait::async_trait;
        use std::sync::Arc;

        /// Transport that decrypts each frame and records the leading plaintext byte.
        struct RecordingTransport {
            recorded_order: Arc<Mutex<Vec<u8>>>,
            read_key: NoiseCipher,
            counter: std::sync::atomic::AtomicU32,
        }

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl crate::transport::Transport for RecordingTransport {
            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
                // Only frames longer than 16 bytes are inspected; every frame in this test
                // carries a 100-byte plaintext, so all of them pass.
                if data.len() > 16 {
                    // Skip the frame header; assumes `encode_frame_into` writes a 3-byte
                    // header here — confirm against wacore::framing.
                    let ciphertext = &data[3..];
                    // Mirror the sender's counter: one increment per inspected frame.
                    let counter = self
                        .counter
                        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);

                    if let Ok(plaintext) = self.read_key.decrypt_with_counter(counter, ciphertext)
                        && !plaintext.is_empty()
                    {
                        let index = plaintext[0];
                        let mut order = self.recorded_order.lock().await;
                        order.push(index);
                    }
                }
                Ok(())
            }

            async fn disconnect(&self) {}
        }

        let recorded_order = Arc::new(Mutex::new(Vec::new()));
        let key = [0u8; 32];
        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");

        // The transport gets its own cipher built from the same key so it can decrypt what
        // the socket's write cipher produces.
        let transport = Arc::new(RecordingTransport {
            recorded_order: recorded_order.clone(),
            read_key: NoiseCipher::new(&key).expect("32-byte key should be valid"),
            counter: std::sync::atomic::AtomicU32::new(0),
        });

        let socket = Arc::new(NoiseSocket::new(
            Arc::new(crate::runtime_impl::TokioRuntime),
            transport,
            write_key,
            read_key,
        ));

        let mut handles = Vec::new();
        for i in 0..10 {
            let socket = socket.clone();
            handles.push(tokio::spawn(async move {
                // First byte identifies the submitting task; padded to 100 bytes total.
                let mut plaintext = vec![i as u8];
                plaintext.extend_from_slice(&[0u8; 99]);
                let out_buf = Vec::with_capacity(256);
                socket.encrypt_and_send(plaintext, out_buf).await
            }));
        }

        for handle in handles {
            let result = handle.await.expect("task should complete");
            assert!(result.is_ok(), "All sends should succeed");
        }

        let order = recorded_order.lock().await;
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(*order, expected, "Sends should maintain FIFO order");
    }

    /// Pins the exact framed size for payloads both at/below and above
    /// `INLINE_ENCRYPT_THRESHOLD`, and checks that `plaintext.len() + 32` bytes of
    /// capacity is always enough for the frame.
    #[tokio::test]
    async fn test_encrypted_buffer_sizing_is_sufficient() {
        use async_trait::async_trait;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Transport that only records the length of the most recent frame.
        struct SizeRecordingTransport {
            last_size: Arc<AtomicUsize>,
        }

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl crate::transport::Transport for SizeRecordingTransport {
            async fn send(&self, data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
                self.last_size.store(data.len(), Ordering::SeqCst);
                Ok(())
            }
            async fn disconnect(&self) {}
        }

        let last_size = Arc::new(AtomicUsize::new(0));
        let transport = Arc::new(SizeRecordingTransport {
            last_size: last_size.clone(),
        });

        let key = [0u8; 32];
        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");

        let socket = NoiseSocket::new(
            Arc::new(crate::runtime_impl::TokioRuntime),
            transport,
            write_key,
            read_key,
        );

        // 16384 is exactly the inline threshold; 20000 takes the `blocking` path.
        let test_sizes = [0, 1, 50, 100, 500, 1000, 1024, 2000, 5000, 16384, 20000];

        for size in test_sizes {
            let plaintext = vec![0xABu8; size];
            let buffer_capacity = plaintext.len() + 32;
            let encrypted_buf = Vec::with_capacity(buffer_capacity);

            let result = socket
                .encrypt_and_send(plaintext.clone(), encrypted_buf)
                .await;

            assert!(
                result.is_ok(),
                "encrypt_and_send should succeed for payload size {}",
                size
            );

            // The send above was awaited to completion, so this is the frame for `size`.
            let actual_encrypted_size = last_size.load(Ordering::SeqCst);

            // Asserted as an exact overhead, not an upper bound despite the name; presumably
            // a 16-byte AEAD tag plus a 3-byte frame header — confirm against NoiseCipher
            // and wacore::framing.
            let expected_max = size + 19;
            assert_eq!(
                actual_encrypted_size, expected_max,
                "Encrypted size for {} byte payload should be {} (got {})",
                size, expected_max, actual_encrypted_size
            );

            assert!(
                buffer_capacity >= actual_encrypted_size,
                "Buffer capacity {} should be >= encrypted size {} for payload size {}",
                buffer_capacity,
                actual_encrypted_size,
                size
            );
        }
    }

    /// Boundary checks around the 16 KiB inline threshold: an empty payload, a payload
    /// exactly at the threshold, and one byte above it.
    #[tokio::test]
    async fn test_encrypted_buffer_sizing_edge_cases() {
        use async_trait::async_trait;
        use std::sync::Arc;

        /// Transport that accepts and discards every frame.
        struct NoOpTransport;

        #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
        #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
        impl crate::transport::Transport for NoOpTransport {
            async fn send(&self, _data: Vec<u8>) -> std::result::Result<(), anyhow::Error> {
                Ok(())
            }
            async fn disconnect(&self) {}
        }

        let transport = Arc::new(NoOpTransport);
        let key = [0u8; 32];
        let write_key = NoiseCipher::new(&key).expect("32-byte key should be valid");
        let read_key = NoiseCipher::new(&key).expect("32-byte key should be valid");

        let socket = NoiseSocket::new(
            Arc::new(crate::runtime_impl::TokioRuntime),
            transport,
            write_key,
            read_key,
        );

        let result = socket
            .encrypt_and_send(vec![], Vec::with_capacity(32))
            .await;
        assert!(result.is_ok(), "Empty payload should encrypt successfully");

        // The threshold check is `<=`, so this size still takes the inline path.
        let at_threshold = vec![0u8; 16 * 1024];
        let result = socket
            .encrypt_and_send(at_threshold, Vec::with_capacity(16 * 1024 + 32))
            .await;
        assert!(
            result.is_ok(),
            "Payload at inline threshold should encrypt successfully"
        );

        // One byte over the threshold: encrypted via wacore::runtime::blocking.
        let above_threshold = vec![0u8; 16 * 1024 + 1];
        let result = socket
            .encrypt_and_send(above_threshold, Vec::with_capacity(16 * 1024 + 33))
            .await;
        assert!(
            result.is_ok(),
            "Payload above inline threshold should encrypt successfully"
        );
    }
}