whatsapp_rust/socket/
noise_socket.rs

1use crate::socket::error::{Result, SocketError};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU32, Ordering};
4use wacore::aes_gcm::{Aes256Gcm, aead::Aead};
5use wacore::handshake::utils::generate_iv;
6
7use crate::socket::FrameSocket;
8use tokio::sync::Mutex;
9
10pub struct NoiseSocket {
11    frame_socket: Arc<Mutex<FrameSocket>>,
12    write_key: Aes256Gcm,
13    read_key: Aes256Gcm,
14    write_counter: Arc<AtomicU32>,
15    read_counter: Arc<AtomicU32>,
16}
17
18impl NoiseSocket {
19    pub fn new(
20        frame_socket: Arc<Mutex<FrameSocket>>,
21        write_key: Aes256Gcm,
22        read_key: Aes256Gcm,
23    ) -> Self {
24        Self {
25            frame_socket,
26            write_key,
27            read_key,
28            write_counter: Arc::new(AtomicU32::new(0)),
29            read_counter: Arc::new(AtomicU32::new(0)),
30        }
31    }
32
33    /// Encrypts `plaintext` into the provided `out` buffer (which is cleared first) and
34    /// returns a slice view of the ciphertext.
35    pub fn encrypt_into<'a>(&self, plaintext: &[u8], out: &'a mut Vec<u8>) -> Result<&'a [u8]> {
36        out.clear();
37        let counter = self.write_counter.fetch_add(1, Ordering::SeqCst);
38        let iv = generate_iv(counter);
39        let ciphertext = self
40            .write_key
41            .encrypt(iv.as_ref().into(), plaintext)
42            .map_err(|e| SocketError::Crypto(e.to_string()))?;
43        out.extend_from_slice(&ciphertext);
44        Ok(out.as_slice())
45    }
46
47    pub async fn encrypt_and_send(
48        &self,
49        mut plaintext_buf: Vec<u8>,
50        mut out_buf: Vec<u8>,
51    ) -> Result<Vec<u8>> {
52        self.encrypt_into(&plaintext_buf, &mut out_buf)?;
53        plaintext_buf.clear();
54        let fs = self.frame_socket.clone();
55        fs.lock().await.send_frame(out_buf).await?;
56        Ok(plaintext_buf)
57    }
58
59    pub fn decrypt_frame(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
60        let counter = self.read_counter.fetch_add(1, Ordering::SeqCst);
61        let iv = generate_iv(counter);
62        self.read_key
63            .decrypt(iv.as_ref().into(), ciphertext)
64            .map_err(|e| SocketError::Crypto(e.to_string()))
65    }
66}