// whatsapp_rust/socket/frame_socket.rs
use std::sync::Arc;

use bytes::{Buf, BytesMut};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use log::{debug, error, info, trace, warn};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::JoinHandle;
use tokio_websockets::{ClientBuilder, MaybeTlsStream, Message, WebSocketStream};
use wacore_binary::consts::WA_CONN_HEADER;

use crate::socket::consts::{FRAME_LENGTH_SIZE, FRAME_MAX_SIZE, URL};
use crate::socket::error::{Result, SocketError};

type RawWs = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WsSink = SplitSink<RawWs, Message>;
type WsStream = SplitStream<RawWs>;

17type OnDisconnectCallback = Box<dyn Fn(bool) + Send>;
18
19pub struct FrameSocket {
20 ws_sink: Arc<Mutex<Option<WsSink>>>,
21 frames_tx: Sender<bytes::Bytes>,
22 on_disconnect: Arc<Mutex<Option<OnDisconnectCallback>>>,
23 is_connected: Arc<Mutex<bool>>,
24 header: Arc<Mutex<Option<Vec<u8>>>>,
25}
26
27impl FrameSocket {
28 pub fn new() -> (Self, Receiver<bytes::Bytes>) {
29 if let Err(e) = rustls::crypto::ring::default_provider().install_default() {
30 debug!("rustls crypto provider install: {:?}", e);
31 }
32
33 let (tx, rx) = mpsc::channel(100);
34 let socket = Self {
35 ws_sink: Arc::new(Mutex::new(None)),
36 frames_tx: tx,
37 on_disconnect: Arc::new(Mutex::new(None)),
38 is_connected: Arc::new(Mutex::new(false)),
39 header: Arc::new(Mutex::new(Some(WA_CONN_HEADER.to_vec()))),
40 };
41 (socket, rx)
42 }
43
44 pub async fn is_connected(&self) -> bool {
45 *self.is_connected.lock().await
46 }
47
48 pub async fn set_on_disconnect(&self, cb: OnDisconnectCallback) {
49 *self.on_disconnect.lock().await = Some(cb);
50 }
51
52 pub async fn connect(&self) -> Result<()> {
53 if self.is_connected().await {
54 return Err(SocketError::SocketAlreadyOpen);
55 }
56
57 info!("Dialing {URL}");
58 let uri: http::Uri = URL.parse().expect("Failed to parse URL");
59 let (client, _response) = match ClientBuilder::from_uri(uri).connect().await {
60 Ok(ok) => ok,
61 Err(e) => {
62 error!("WebSocket connect failed: {e:?}");
63 return Err(SocketError::WebSocket(e));
64 }
65 };
66
67 let (sink, stream) = client.split();
68 *self.ws_sink.lock().await = Some(sink);
69 *self.is_connected.lock().await = true;
70
71 let frames_tx_clone = self.frames_tx.clone();
72 let is_connected_clone = self.is_connected.clone();
73 let on_disconnect_clone = self.on_disconnect.clone();
74
75 tokio::task::spawn(Self::read_pump(
76 stream,
77 frames_tx_clone,
78 is_connected_clone,
79 on_disconnect_clone,
80 ));
81
82 Ok(())
83 }
84
85 pub async fn send_frame(&self, mut data: Vec<u8>) -> Result<()> {
86 let mut sink_guard = self.ws_sink.lock().await;
87 let sink = sink_guard.as_mut().ok_or(SocketError::SocketClosed)?;
88
89 let data_len = data.len();
90 if data_len >= FRAME_MAX_SIZE {
91 return Err(SocketError::FrameTooLarge {
92 max: FRAME_MAX_SIZE,
93 got: data_len,
94 });
95 }
96
97 let frame_header = self.header.lock().await.take().unwrap_or_default();
99 let header_len = frame_header.len();
100 let prefix_len = header_len + FRAME_LENGTH_SIZE;
101
102 data.reserve(prefix_len);
103 let original_len = data.len();
104 data.resize(original_len + prefix_len, 0);
105 data.copy_within(0..original_len, prefix_len);
106
107 if header_len > 0 {
109 data[0..header_len].copy_from_slice(&frame_header);
110 }
111 let len_bytes = u32::to_be_bytes(data_len as u32);
112 data[header_len..prefix_len].copy_from_slice(&len_bytes[1..]);
113
114 debug!(
115 "--> Sending frame: payload {} bytes, total {} bytes",
116 data_len,
117 data.len()
118 );
119 sink.send(Message::binary(data)).await?;
120 Ok(())
121 }
122
123 async fn read_pump(
124 mut stream: WsStream,
125 frames_tx: mpsc::Sender<bytes::Bytes>,
126 is_connected: Arc<Mutex<bool>>,
127 on_disconnect: Arc<Mutex<Option<OnDisconnectCallback>>>,
128 ) {
129 let mut buffer = BytesMut::new();
130
131 loop {
132 match stream.next().await {
133 Some(Ok(msg)) => {
134 if msg.is_binary() {
135 let data = msg.as_payload();
136 debug!("<-- Received WebSocket message: {} bytes", data.len());
137 buffer.extend_from_slice(data);
138
139 while buffer.len() >= FRAME_LENGTH_SIZE {
140 let frame_len = ((buffer[0] as usize) << 16)
141 | ((buffer[1] as usize) << 8)
142 | (buffer[2] as usize);
143
144 if buffer.len() >= FRAME_LENGTH_SIZE + frame_len {
145 buffer.advance(FRAME_LENGTH_SIZE);
146 let frame_data = buffer.split_to(frame_len).freeze();
147 trace!("<-- Assembled frame: {} bytes", frame_data.len());
148 if frames_tx.send(frame_data).await.is_err() {
149 warn!("Frame receiver dropped, closing read pump");
150 break;
151 }
152 } else {
153 break;
154 }
155 }
156 } else if msg.is_close() {
157 trace!("Received close frame");
158 break;
159 }
160 }
161 Some(Err(e)) => {
162 error!("Error reading from websocket: {e}");
163 break;
164 }
165 None => {
166 trace!("Websocket stream ended");
167 break;
168 }
169 }
170 }
171
172 *is_connected.lock().await = false;
173 if let Some(cb) = on_disconnect.lock().await.as_ref() {
174 (cb)(true);
175 }
176 }
177
178 pub async fn close(&self) {
179 let mut is_connected = self.is_connected.lock().await;
180 if *is_connected {
181 *is_connected = false;
182 *self.ws_sink.lock().await = None;
183 if let Some(cb) = self.on_disconnect.lock().await.as_ref() {
184 (cb)(false);
185 }
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Replays the in-place prefixing steps used by `send_frame` on a small
    /// payload and checks the header, the 3-byte big-endian length and the
    /// shifted payload.
    #[test]
    fn test_prefix_shift_algorithm() {
        let header: Vec<u8> = WA_CONN_HEADER.to_vec();
        let payload: Vec<u8> = vec![1, 2, 3, 4, 5];
        let (hlen, plen) = (header.len(), payload.len());
        let prefix = hlen + FRAME_LENGTH_SIZE;

        let mut frame = payload.clone();
        frame.reserve(prefix);
        frame.resize(plen + prefix, 0);
        frame.copy_within(0..plen, prefix);
        if hlen > 0 {
            frame[..hlen].copy_from_slice(&header);
        }
        let be = (plen as u32).to_be_bytes();
        frame[hlen..prefix].copy_from_slice(&be[1..]);

        assert_eq!(&frame[..hlen], header.as_slice());

        // Fold the three length bytes back into a big-endian integer.
        let decoded = frame[hlen..hlen + 3]
            .iter()
            .fold(0usize, |acc, &byte| (acc << 8) | byte as usize);
        assert_eq!(decoded, plen);

        assert_eq!(&frame[prefix..prefix + plen], payload.as_slice());
    }
}