1use crate::socket::error::{EncryptSendError, Result, SocketError};
2use crate::transport::Transport;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, Ordering};
5use tokio::sync::{mpsc, oneshot};
6use tokio::task::JoinHandle;
7use wacore::aes_gcm::{
8 Aes256Gcm,
9 aead::{Aead, AeadInPlace},
10};
11use wacore::handshake::utils::generate_iv;
12
/// Payloads of at most this many bytes are encrypted directly on the sender task; larger
/// payloads are encrypted via `tokio::task::spawn_blocking` so a big encryption does not
/// occupy an async worker.
const INLINE_ENCRYPT_THRESHOLD: usize = 16 * 1024;

/// Outcome of one send job: on success the caller's two buffers come back (cleared) for
/// reuse; on failure the `EncryptSendError` carries the buffers instead.
type SendResult = std::result::Result<(Vec<u8>, Vec<u8>), EncryptSendError>;

/// A unit of work queued for the background sender task.
struct SendJob {
    /// Plaintext to encrypt; the allocation is handed back to the caller afterwards.
    plaintext_buf: Vec<u8>,
    /// Scratch buffer that receives the framed ciphertext before it is written out.
    out_buf: Vec<u8>,
    /// One-shot channel on which the sender task reports this job's result.
    response_tx: oneshot::Sender<SendResult>,
}
24
/// Encrypted socket that applies AES-256-GCM with counter-derived nonces on top of a
/// [`Transport`].
///
/// Outgoing frames are serialized through a single background task (fed by
/// `encrypt_and_send`) that owns the write nonce counter; incoming frames are decrypted
/// synchronously by `decrypt_frame`.
pub struct NoiseSocket {
    /// Key used to decrypt frames received from the peer.
    read_key: Arc<Aes256Gcm>,
    /// Nonce counter for `read_key`; advanced by each `decrypt_frame` call.
    read_counter: Arc<AtomicU32>,
    /// Queue feeding the background sender task.
    send_job_tx: mpsc::Sender<SendJob>,
    /// Handle of the background sender task; aborted when the socket is dropped.
    sender_task_handle: JoinHandle<()>,
}
37
38impl NoiseSocket {
39 pub fn new(transport: Arc<dyn Transport>, write_key: Aes256Gcm, read_key: Aes256Gcm) -> Self {
40 let write_key = Arc::new(write_key);
41 let read_key = Arc::new(read_key);
42
43 let (send_job_tx, send_job_rx) = mpsc::channel::<SendJob>(32);
46
47 let transport_clone = transport.clone();
49 let write_key_clone = write_key.clone();
50 let sender_task_handle = tokio::spawn(Self::sender_task(
51 transport_clone,
52 write_key_clone,
53 send_job_rx,
54 ));
55
56 Self {
57 read_key,
58 read_counter: Arc::new(AtomicU32::new(0)),
59 send_job_tx,
60 sender_task_handle,
61 }
62 }
63
64 async fn sender_task(
68 transport: Arc<dyn Transport>,
69 write_key: Arc<Aes256Gcm>,
70 mut send_job_rx: mpsc::Receiver<SendJob>,
71 ) {
72 let mut write_counter: u32 = 0;
73
74 while let Some(job) = send_job_rx.recv().await {
75 let result = Self::process_send_job(
76 &transport,
77 &write_key,
78 &mut write_counter,
79 job.plaintext_buf,
80 job.out_buf,
81 )
82 .await;
83
84 let _ = job.response_tx.send(result);
86 }
87
88 }
90
91 async fn process_send_job(
93 transport: &Arc<dyn Transport>,
94 write_key: &Arc<Aes256Gcm>,
95 write_counter: &mut u32,
96 mut plaintext_buf: Vec<u8>,
97 mut out_buf: Vec<u8>,
98 ) -> SendResult {
99 let counter = *write_counter;
100 *write_counter = write_counter.wrapping_add(1);
101
102 if plaintext_buf.len() <= INLINE_ENCRYPT_THRESHOLD {
104 out_buf.clear();
106 out_buf.extend_from_slice(&plaintext_buf);
107 plaintext_buf.clear();
108
109 let iv = generate_iv(counter);
110 if let Err(e) = write_key.encrypt_in_place(iv.as_ref().into(), b"", &mut out_buf) {
111 return Err(EncryptSendError::crypto(
112 anyhow::anyhow!(e.to_string()),
113 plaintext_buf,
114 out_buf,
115 ));
116 }
117
118 let ciphertext_len = out_buf.len();
121 plaintext_buf.extend_from_slice(&out_buf);
122 out_buf.clear();
123 if let Err(e) = wacore::framing::encode_frame_into(
124 &plaintext_buf[..ciphertext_len],
125 None,
126 &mut out_buf,
127 ) {
128 plaintext_buf.clear();
129 return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
130 }
131 plaintext_buf.clear();
132 } else {
133 let write_key = write_key.clone();
135
136 let plaintext_arc = Arc::new(plaintext_buf);
137 let plaintext_arc_for_task = plaintext_arc.clone();
138
139 let spawn_result = tokio::task::spawn_blocking(move || {
140 let iv = generate_iv(counter);
141 write_key.encrypt(iv.as_ref().into(), &plaintext_arc_for_task[..])
142 })
143 .await;
144
145 plaintext_buf = Arc::try_unwrap(plaintext_arc).unwrap_or_else(|arc| (*arc).clone());
146
147 let ciphertext = match spawn_result {
148 Ok(Ok(c)) => c,
149 Ok(Err(e)) => {
150 return Err(EncryptSendError::crypto(
151 anyhow::anyhow!(e.to_string()),
152 plaintext_buf,
153 out_buf,
154 ));
155 }
156 Err(join_err) => {
157 return Err(EncryptSendError::join(join_err, plaintext_buf, out_buf));
158 }
159 };
160
161 plaintext_buf.clear();
162 out_buf.clear();
163 if let Err(e) = wacore::framing::encode_frame_into(&ciphertext, None, &mut out_buf) {
164 return Err(EncryptSendError::framing(e, plaintext_buf, out_buf));
165 }
166 }
167
168 if let Err(e) = transport.send(&out_buf).await {
169 return Err(EncryptSendError::transport(e, plaintext_buf, out_buf));
170 }
171
172 out_buf.clear();
173 Ok((plaintext_buf, out_buf))
174 }
175
176 pub async fn encrypt_and_send(&self, plaintext_buf: Vec<u8>, out_buf: Vec<u8>) -> SendResult {
177 let (response_tx, response_rx) = oneshot::channel();
178
179 let job = SendJob {
180 plaintext_buf,
181 out_buf,
182 response_tx,
183 };
184
185 if let Err(send_err) = self.send_job_tx.send(job).await {
187 let job = send_err.0;
189 return Err(EncryptSendError::channel_closed(
190 job.plaintext_buf,
191 job.out_buf,
192 ));
193 }
194
195 match response_rx.await {
197 Ok(result) => result,
198 Err(_) => {
199 Err(EncryptSendError::channel_closed(Vec::new(), Vec::new()))
201 }
202 }
203 }
204
205 pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
206 let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
207 let iv = generate_iv(counter);
208 self.read_key
209 .decrypt(iv.as_ref().into(), ciphertext)
210 .map_err(|e| SocketError::Crypto(e.to_string()))
211 }
212}
213
214impl Drop for NoiseSocket {
215 fn drop(&mut self) {
216 self.sender_task_handle.abort();
220 }
221}
222
223#[cfg(test)]
224mod tests {
225 use super::*;
226 use wacore::aes_gcm::{Aes256Gcm, KeyInit};
227
228 #[tokio::test]
229 async fn test_encrypt_and_send_returns_both_buffers() {
230 let transport = Arc::new(crate::transport::mock::MockTransport);
232
233 let key = [0u8; 32];
235 let write_key =
236 Aes256Gcm::new_from_slice(&key).expect("32-byte key should be valid for AES-256-GCM");
237 let read_key =
238 Aes256Gcm::new_from_slice(&key).expect("32-byte key should be valid for AES-256-GCM");
239
240 let socket = NoiseSocket::new(transport, write_key, read_key);
241
242 let plaintext_buf = Vec::with_capacity(1024);
244 let encrypted_buf = Vec::with_capacity(1024);
245
246 let plaintext_capacity = plaintext_buf.capacity();
248 let encrypted_capacity = encrypted_buf.capacity();
249
250 let result = socket.encrypt_and_send(plaintext_buf, encrypted_buf).await;
252
253 assert!(result.is_ok(), "encrypt_and_send should succeed");
254
255 let (returned_plaintext, returned_encrypted) =
256 result.expect("encrypt_and_send result should unwrap after is_ok check");
257
258 assert_eq!(
260 returned_plaintext.capacity(),
261 plaintext_capacity,
262 "Plaintext buffer should maintain its capacity"
263 );
264 assert_eq!(
265 returned_encrypted.capacity(),
266 encrypted_capacity,
267 "Encrypted buffer should maintain its capacity"
268 );
269
270 assert!(
272 returned_plaintext.is_empty(),
273 "Returned plaintext buffer should be cleared"
274 );
275 assert!(
276 returned_encrypted.is_empty(),
277 "Returned encrypted buffer should be cleared"
278 );
279 }
280
281 #[tokio::test]
282 async fn test_concurrent_sends_maintain_order() {
283 use async_trait::async_trait;
284 use std::sync::Arc;
285 use tokio::sync::Mutex;
286
287 struct RecordingTransport {
290 recorded_order: Arc<Mutex<Vec<u8>>>,
291 read_key: Aes256Gcm,
292 counter: std::sync::atomic::AtomicU32,
293 }
294
295 #[async_trait]
296 impl crate::transport::Transport for RecordingTransport {
297 async fn send(&self, data: &[u8]) -> std::result::Result<(), anyhow::Error> {
298 if data.len() > 16 {
300 let ciphertext = &data[3..];
302 let counter = self
303 .counter
304 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
305 let iv = super::generate_iv(counter);
306
307 if let Ok(plaintext) = self.read_key.decrypt(iv.as_ref().into(), ciphertext)
308 && !plaintext.is_empty()
309 {
310 let index = plaintext[0];
311 let mut order = self.recorded_order.lock().await;
312 order.push(index);
313 }
314 }
315 Ok(())
316 }
317
318 async fn disconnect(&self) {}
319 }
320
321 let recorded_order = Arc::new(Mutex::new(Vec::new()));
322 let key = [0u8; 32];
323 let write_key =
324 Aes256Gcm::new_from_slice(&key).expect("32-byte key should be valid for AES-256-GCM");
325 let read_key =
326 Aes256Gcm::new_from_slice(&key).expect("32-byte key should be valid for AES-256-GCM");
327
328 let transport = Arc::new(RecordingTransport {
329 recorded_order: recorded_order.clone(),
330 read_key: Aes256Gcm::new_from_slice(&key)
331 .expect("32-byte key should be valid for AES-256-GCM"),
332 counter: std::sync::atomic::AtomicU32::new(0),
333 });
334
335 let socket = Arc::new(NoiseSocket::new(transport, write_key, read_key));
336
337 let mut handles = Vec::new();
339 for i in 0..10 {
340 let socket = socket.clone();
341 handles.push(tokio::spawn(async move {
342 let mut plaintext = vec![i as u8];
344 plaintext.extend_from_slice(&[0u8; 99]);
345 let out_buf = Vec::with_capacity(256);
346 socket.encrypt_and_send(plaintext, out_buf).await
347 }));
348 }
349
350 for handle in handles {
352 let result = handle.await.expect("task should complete");
353 assert!(result.is_ok(), "All sends should succeed");
354 }
355
356 let order = recorded_order.lock().await;
358 let expected: Vec<u8> = (0..10).collect();
359 assert_eq!(*order, expected, "Sends should maintain FIFO order");
360 }
361}