Skip to main content

whatsapp_rust/
handshake.rs

1use crate::socket::NoiseSocket;
2use crate::transport::{Transport, TransportEvent};
3use log::{debug, info, warn};
4use prost::Message;
5use std::sync::Arc;
6use std::time::Duration;
7use thiserror::Error;
8use wacore::handshake::{
9    HandshakeError as CoreHandshakeError, HandshakeState, build_handshake_header,
10};
11use wacore::runtime::{Runtime, timeout as rt_timeout};
12use wacore_binary::consts::{NOISE_START_PATTERN, WA_CONN_HEADER};
13
/// Time allowed when waiting for the server's Noise handshake response (20 seconds).
const NOISE_HANDSHAKE_RESPONSE_TIMEOUT: Duration = Duration::from_millis(20_000);
15
16#[derive(Debug, Error)]
17pub enum HandshakeError {
18    #[error("Transport error: {0}")]
19    Transport(#[from] anyhow::Error),
20    #[error("Core handshake error: {0}")]
21    Core(#[from] CoreHandshakeError),
22    #[error("Timed out waiting for handshake response")]
23    Timeout,
24    #[error("Unexpected event during handshake: {0}")]
25    UnexpectedEvent(String),
26}
27
28type Result<T> = std::result::Result<T, HandshakeError>;
29
30pub async fn do_handshake(
31    runtime: Arc<dyn Runtime>,
32    device: &crate::store::Device,
33    transport: Arc<dyn Transport>,
34    transport_events: &mut async_channel::Receiver<TransportEvent>,
35) -> Result<Arc<NoiseSocket>> {
36    // Prepare the client payload (convert Device-specific data to bytes)
37    let client_payload = device.core.get_client_payload().encode_to_vec();
38
39    let mut handshake_state = HandshakeState::new(
40        device.core.noise_key.clone(),
41        client_payload,
42        NOISE_START_PATTERN,
43        &WA_CONN_HEADER,
44    )?;
45    let mut frame_decoder = wacore::framing::FrameDecoder::new();
46
47    debug!("--> Sending ClientHello");
48    let client_hello_bytes = handshake_state.build_client_hello()?;
49
50    // Build the connection header, optionally with edge routing pre-intro
51    let (header, used_edge_routing) =
52        build_handshake_header(device.core.edge_routing_info.as_deref());
53    if used_edge_routing {
54        debug!("Sending edge routing pre-intro for optimized reconnection");
55    } else if device.core.edge_routing_info.is_some() {
56        warn!("Edge routing info provided but not used (possibly too large)");
57    }
58
59    // First message includes the WA connection header (with optional edge routing)
60    let framed = wacore::framing::encode_frame(&client_hello_bytes, Some(&header))
61        .map_err(HandshakeError::Transport)?;
62    transport.send(framed).await?;
63
64    // Wait for server response frame
65    let resp_frame = loop {
66        match rt_timeout(
67            &*runtime,
68            NOISE_HANDSHAKE_RESPONSE_TIMEOUT,
69            transport_events.recv(),
70        )
71        .await
72        {
73            Ok(Ok(TransportEvent::DataReceived(data))) => {
74                // Feed data into decoder
75                frame_decoder.feed(&data);
76
77                // Try to decode a frame
78                if let Some(frame) = frame_decoder.decode_frame() {
79                    break frame;
80                }
81                // If no complete frame yet, continue waiting for more data
82                continue;
83            }
84            Ok(Ok(TransportEvent::Connected)) => {
85                // Ignore Connected event, we're already connected
86                continue;
87            }
88            Ok(Ok(TransportEvent::Disconnected)) => {
89                return Err(HandshakeError::UnexpectedEvent(
90                    "Disconnected during handshake".to_string(),
91                ));
92            }
93            Ok(Err(_)) => return Err(HandshakeError::Timeout), // Channel closed
94            Err(_) => return Err(HandshakeError::Timeout),
95        }
96    };
97
98    debug!("<-- Received handshake response, building ClientFinish");
99    let client_finish_bytes =
100        handshake_state.read_server_hello_and_build_client_finish(&resp_frame)?;
101
102    debug!("--> Sending ClientFinish");
103    // Subsequent messages don't need the header
104    let framed = wacore::framing::encode_frame(&client_finish_bytes, None)
105        .map_err(HandshakeError::Transport)?;
106    transport.send(framed).await?;
107
108    let (write_key, read_key) = handshake_state.finish()?;
109    info!("Handshake complete, switching to encrypted communication");
110
111    Ok(Arc::new(NoiseSocket::new(
112        runtime, transport, write_key, read_key,
113    )))
114}