// whatsapp_rust/socket/frame_socket.rs

1use crate::socket::consts::{FRAME_LENGTH_SIZE, FRAME_MAX_SIZE, URL};
2use crate::socket::error::{Result, SocketError};
3use bytes::{Buf, BytesMut};
4use futures_util::stream::{SplitSink, SplitStream};
5use futures_util::{SinkExt, StreamExt};
6use log::{debug, error, info, trace, warn};
7use std::sync::Arc;
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio::sync::mpsc::{self, Receiver, Sender};
11use tokio_websockets::{ClientBuilder, MaybeTlsStream, Message, WebSocketStream};
12type RawWs = WebSocketStream<MaybeTlsStream<TcpStream>>;
13type WsSink = SplitSink<RawWs, Message>;
14type WsStream = SplitStream<RawWs>;
15use wacore_binary::consts::WA_CONN_HEADER;
16
/// Callback run when the socket disconnects.
///
/// `FrameSocket` passes `true` when the read loop observes the connection ending,
/// and `false` when the disconnect was requested locally through `close()`.
type OnDisconnectCallback = Box<dyn Fn(bool) + Send + 'static>;
18
19pub struct FrameSocket {
20    ws_sink: Arc<Mutex<Option<WsSink>>>,
21    frames_tx: Sender<bytes::Bytes>,
22    on_disconnect: Arc<Mutex<Option<OnDisconnectCallback>>>,
23    is_connected: Arc<Mutex<bool>>,
24    header: Arc<Mutex<Option<Vec<u8>>>>,
25}
26
27impl FrameSocket {
28    pub fn new() -> (Self, Receiver<bytes::Bytes>) {
29        if let Err(e) = rustls::crypto::ring::default_provider().install_default() {
30            debug!("rustls crypto provider install: {:?}", e);
31        }
32
33        let (tx, rx) = mpsc::channel(100);
34        let socket = Self {
35            ws_sink: Arc::new(Mutex::new(None)),
36            frames_tx: tx,
37            on_disconnect: Arc::new(Mutex::new(None)),
38            is_connected: Arc::new(Mutex::new(false)),
39            header: Arc::new(Mutex::new(Some(WA_CONN_HEADER.to_vec()))),
40        };
41        (socket, rx)
42    }
43
44    pub async fn is_connected(&self) -> bool {
45        *self.is_connected.lock().await
46    }
47
48    pub async fn set_on_disconnect(&self, cb: OnDisconnectCallback) {
49        *self.on_disconnect.lock().await = Some(cb);
50    }
51
52    pub async fn connect(&self) -> Result<()> {
53        if self.is_connected().await {
54            return Err(SocketError::SocketAlreadyOpen);
55        }
56
57        info!("Dialing {URL}");
58        let uri: http::Uri = URL.parse().expect("Failed to parse URL");
59        let (client, _response) = match ClientBuilder::from_uri(uri).connect().await {
60            Ok(ok) => ok,
61            Err(e) => {
62                error!("WebSocket connect failed: {e:?}");
63                return Err(SocketError::WebSocket(e));
64            }
65        };
66
67        let (sink, stream) = client.split();
68        *self.ws_sink.lock().await = Some(sink);
69        *self.is_connected.lock().await = true;
70
71        let frames_tx_clone = self.frames_tx.clone();
72        let is_connected_clone = self.is_connected.clone();
73        let on_disconnect_clone = self.on_disconnect.clone();
74
75        tokio::task::spawn(Self::read_pump(
76            stream,
77            frames_tx_clone,
78            is_connected_clone,
79            on_disconnect_clone,
80        ));
81
82        Ok(())
83    }
84
85    pub async fn send_frame(&self, mut data: Vec<u8>) -> Result<()> {
86        let mut sink_guard = self.ws_sink.lock().await;
87        let sink = sink_guard.as_mut().ok_or(SocketError::SocketClosed)?;
88
89        let data_len = data.len();
90        if data_len >= FRAME_MAX_SIZE {
91            return Err(SocketError::FrameTooLarge {
92                max: FRAME_MAX_SIZE,
93                got: data_len,
94            });
95        }
96
97        // Take (or empty) the header (conn header only needed once; subsequent calls will get empty vec)
98        let frame_header = self.header.lock().await.take().unwrap_or_default();
99        let header_len = frame_header.len();
100        let prefix_len = header_len + FRAME_LENGTH_SIZE;
101
102        data.reserve(prefix_len);
103        let original_len = data.len();
104        data.resize(original_len + prefix_len, 0);
105        data.copy_within(0..original_len, prefix_len);
106
107        // Write header (if any) and 3-byte length (big-endian, 24-bit like existing logic).
108        if header_len > 0 {
109            data[0..header_len].copy_from_slice(&frame_header);
110        }
111        let len_bytes = u32::to_be_bytes(data_len as u32);
112        data[header_len..prefix_len].copy_from_slice(&len_bytes[1..]);
113
114        debug!(
115            "--> Sending frame: payload {} bytes, total {} bytes",
116            data_len,
117            data.len()
118        );
119        sink.send(Message::binary(data)).await?;
120        Ok(())
121    }
122
123    async fn read_pump(
124        mut stream: WsStream,
125        frames_tx: mpsc::Sender<bytes::Bytes>,
126        is_connected: Arc<Mutex<bool>>,
127        on_disconnect: Arc<Mutex<Option<OnDisconnectCallback>>>,
128    ) {
129        let mut buffer = BytesMut::new();
130
131        loop {
132            match stream.next().await {
133                Some(Ok(msg)) => {
134                    if msg.is_binary() {
135                        let data = msg.as_payload();
136                        debug!("<-- Received WebSocket message: {} bytes", data.len());
137                        buffer.extend_from_slice(data);
138
139                        while buffer.len() >= FRAME_LENGTH_SIZE {
140                            let frame_len = ((buffer[0] as usize) << 16)
141                                | ((buffer[1] as usize) << 8)
142                                | (buffer[2] as usize);
143
144                            if buffer.len() >= FRAME_LENGTH_SIZE + frame_len {
145                                buffer.advance(FRAME_LENGTH_SIZE);
146                                let frame_data = buffer.split_to(frame_len).freeze();
147                                trace!("<-- Assembled frame: {} bytes", frame_data.len());
148                                if frames_tx.send(frame_data).await.is_err() {
149                                    warn!("Frame receiver dropped, closing read pump");
150                                    break;
151                                }
152                            } else {
153                                break;
154                            }
155                        }
156                    } else if msg.is_close() {
157                        trace!("Received close frame");
158                        break;
159                    }
160                }
161                Some(Err(e)) => {
162                    error!("Error reading from websocket: {e}");
163                    break;
164                }
165                None => {
166                    trace!("Websocket stream ended");
167                    break;
168                }
169            }
170        }
171
172        *is_connected.lock().await = false;
173        if let Some(cb) = on_disconnect.lock().await.as_ref() {
174            (cb)(true);
175        }
176    }
177
178    pub async fn close(&self) {
179        let mut is_connected = self.is_connected.lock().await;
180        if *is_connected {
181            *is_connected = false;
182            *self.ws_sink.lock().await = None;
183            if let Some(cb) = self.on_disconnect.lock().await.as_ref() {
184                (cb)(false);
185            }
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the in-place "slide the payload right, then write the prefix"
    /// framing layout: connection header, then the big-endian payload length
    /// in `FRAME_LENGTH_SIZE` bytes, then the untouched payload.
    #[test]
    fn test_prefix_shift_algorithm() {
        let header: Vec<u8> = WA_CONN_HEADER.to_vec();
        let payload = vec![1u8, 2, 3, 4, 5];
        let payload_len = payload.len();
        let prefix_len = header.len() + FRAME_LENGTH_SIZE;

        // Make room at the front and move the payload to the back.
        let mut frame = payload.clone();
        frame.reserve(prefix_len);
        frame.resize(payload_len + prefix_len, 0);
        frame.copy_within(..payload_len, prefix_len);

        // Fill in the prefix: header, then the low three length bytes.
        if !header.is_empty() {
            frame[..header.len()].copy_from_slice(&header);
        }
        let be_len = (payload_len as u32).to_be_bytes();
        frame[header.len()..prefix_len].copy_from_slice(&be_len[1..]);

        let (head, rest) = frame.split_at(header.len());
        let (len_field, body) = rest.split_at(FRAME_LENGTH_SIZE);

        assert_eq!(head, header.as_slice());
        let decoded_len = len_field
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
        assert_eq!(decoded_len, payload_len);
        assert_eq!(body, payload.as_slice());
    }
}