Skip to main content

vox_core/
handshake.rs

1use vox_types::{
2    ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, LinkTxPermit,
3    ResumeKeyBytes, Schema, SessionResumeKey, SessionRole, WriteSlot,
4};
5
/// Errors that can end a CBOR session handshake.
#[derive(Debug)]
pub enum HandshakeError {
    /// The underlying link failed while reserving, writing or receiving a frame.
    Io(std::io::Error),
    /// A handshake message could not be CBOR-encoded.
    Encode(String),
    /// A received frame could not be CBOR-decoded into a handshake message.
    Decode(String),
    /// The link was closed by the peer before the handshake completed.
    PeerClosed,
    /// The peer violated the handshake protocol (or a local protocol-level check failed).
    Protocol(String),
    /// The peer explicitly rejected the handshake with a `Sorry`, carrying its reason.
    Sorry(String),
    /// The session cannot be resumed.
    NotResumable,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "handshake io error: {e}"),
            Self::Encode(e) => write!(f, "handshake encode error: {e}"),
            Self::Decode(e) => write!(f, "handshake decode error: {e}"),
            Self::PeerClosed => write!(f, "peer closed during handshake"),
            Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
            Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
            Self::NotResumable => write!(f, "session is not resumable"),
        }
    }
}

impl std::error::Error for HandshakeError {
    /// Expose the wrapped I/O error so error-chain walkers (and `anyhow`-style reporters)
    /// can reach the original cause; every other variant is a leaf error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            _ => None,
        }
    }
}
32
33/// Extract the Message schema from the static shape.
34fn message_schema() -> Vec<Schema> {
35    vox_types::extract_schemas(<vox_types::Message<'static> as facet::Facet<'static>>::SHAPE)
36        .expect("schema extraction")
37        .schemas
38}
39
40/// Send a CBOR-encoded handshake message on a raw link.
41async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
42    let bytes = facet_cbor::to_vec(msg).map_err(|e| HandshakeError::Encode(e.to_string()))?;
43    vox_types::dlog!(
44        "[handshake] send {:?} ({} bytes)",
45        handshake_tag(msg),
46        bytes.len()
47    );
48    let permit = tx.reserve().await.map_err(HandshakeError::Io)?;
49    let mut slot = permit.alloc(bytes.len()).map_err(HandshakeError::Io)?;
50    slot.as_mut_slice().copy_from_slice(&bytes);
51    slot.commit();
52    Ok(())
53}
54
55/// Receive and decode a CBOR handshake message from a raw link.
56async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
57    let backing = rx
58        .recv()
59        .await
60        .map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
61        .ok_or(HandshakeError::PeerClosed)?;
62    vox_types::dlog!(
63        "[handshake] recv raw frame ({} bytes)",
64        backing.as_bytes().len()
65    );
66    let msg = facet_cbor::from_slice(backing.as_bytes())
67        .map_err(|e| HandshakeError::Decode(e.to_string()))?;
68    vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
69    Ok(msg)
70}
71
72fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
73    match msg {
74        HandshakeMessage::Hello(_) => "Hello",
75        HandshakeMessage::HelloYourself(_) => "HelloYourself",
76        HandshakeMessage::LetsGo(_) => "LetsGo",
77        HandshakeMessage::Sorry(_) => "Sorry",
78    }
79}
80
81// r[impl session.handshake]
82// r[impl session.handshake.cbor]
83/// Perform the CBOR handshake as the initiator.
84///
85/// Three-step exchange:
86/// 1. Send Hello
87/// 2. Receive HelloYourself (or Sorry)
88/// 3. Send LetsGo (or Sorry)
89pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
90    tx: &Tx,
91    rx: &mut Rx,
92    settings: ConnectionSettings,
93    supports_retry: bool,
94    resume_key: Option<&SessionResumeKey>,
95) -> Result<HandshakeResult, HandshakeError> {
96    let our_schema = message_schema();
97
98    let hello = vox_types::Hello {
99        parity: settings.parity,
100        connection_settings: settings.clone(),
101        message_payload_schema: our_schema.clone(),
102        supports_retry,
103        resume_key: resume_key.map(ResumeKeyBytes::from_key),
104    };
105
106    // Step 1: Send Hello
107    send_handshake(tx, &HandshakeMessage::Hello(hello)).await?;
108
109    // Step 2: Receive HelloYourself or Sorry
110    let response = recv_handshake(rx).await?;
111    let hy = match response {
112        HandshakeMessage::HelloYourself(hy) => hy,
113        HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
114        _ => {
115            return Err(HandshakeError::Protocol(
116                "expected HelloYourself or Sorry".into(),
117            ));
118        }
119    };
120
121    // Step 3: Send LetsGo
122    // TODO: Compare schemas and send Sorry if incompatible
123    send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;
124
125    let session_resume_key = hy.resume_key.as_ref().and_then(|k| k.to_key());
126
127    Ok(HandshakeResult {
128        role: SessionRole::Initiator,
129        our_settings: settings,
130        peer_settings: hy.connection_settings,
131        peer_supports_retry: hy.supports_retry,
132        session_resume_key,
133        peer_resume_key: None, // initiator doesn't receive a peer resume key
134        our_schema,
135        peer_schema: hy.message_payload_schema,
136    })
137}
138
139// r[impl session.handshake]
140// r[impl session.handshake.cbor]
141/// Perform the CBOR handshake as the acceptor.
142///
143/// Three-step exchange:
144/// 1. Receive Hello
145/// 2. Send HelloYourself (or Sorry)
146/// 3. Receive LetsGo (or Sorry)
147pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
148    tx: &Tx,
149    rx: &mut Rx,
150    settings: ConnectionSettings,
151    supports_retry: bool,
152    resumable: bool,
153    expected_resume_key: Option<&SessionResumeKey>,
154) -> Result<HandshakeResult, HandshakeError> {
155    // Step 1: Receive Hello
156    let hello = match recv_handshake(rx).await? {
157        HandshakeMessage::Hello(h) => h,
158        _ => return Err(HandshakeError::Protocol("expected Hello".into())),
159    };
160
161    // Validate resume key if this is a resumption attempt
162    if let Some(expected) = expected_resume_key {
163        let actual = hello.resume_key.as_ref().and_then(|k| k.to_key());
164        match actual {
165            Some(actual) if actual == *expected => {} // OK
166            _ => {
167                let reason = "session resume key mismatch".to_string();
168                send_handshake(
169                    tx,
170                    &HandshakeMessage::Sorry(vox_types::Sorry {
171                        reason: reason.clone(),
172                    }),
173                )
174                .await?;
175                return Err(HandshakeError::Protocol(reason));
176            }
177        }
178    }
179
180    // Acceptor adopts opposite parity
181    let our_settings = ConnectionSettings {
182        parity: hello.parity.other(),
183        ..settings
184    };
185
186    // Generate resume key if we're resumable
187    let our_resume_key = if resumable {
188        Some(fresh_resume_key()?)
189    } else {
190        None
191    };
192
193    let our_schema = message_schema();
194
195    // Step 2: Send HelloYourself
196    let hy = vox_types::HelloYourself {
197        connection_settings: our_settings.clone(),
198        message_payload_schema: our_schema.clone(),
199        supports_retry,
200        resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
201    };
202    send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;
203
204    // Step 3: Receive LetsGo or Sorry
205    let response = recv_handshake(rx).await?;
206    match response {
207        HandshakeMessage::LetsGo(_) => {}
208        HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
209        _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
210    }
211
212    let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
213
214    Ok(HandshakeResult {
215        role: SessionRole::Acceptor,
216        our_settings,
217        peer_settings: hello.connection_settings,
218        peer_supports_retry: hello.supports_retry,
219        session_resume_key: our_resume_key,
220        peer_resume_key,
221        our_schema,
222        peer_schema: hello.message_payload_schema,
223    })
224}
225
226fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
227    let mut bytes = [0u8; 16];
228    getrandom::fill(&mut bytes).map_err(|error| {
229        HandshakeError::Protocol(format!("failed to generate session key: {error}"))
230    })?;
231    Ok(SessionResumeKey(bytes))
232}