Skip to main content

vox_core/
handshake.rs

1use vox_types::{
2    ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, ResumeKeyBytes, Schema,
3    SessionResumeKey, SessionRole,
4};
5
/// Reason text used whenever a zero `initial_channel_credit` is rejected:
/// returned inside `HandshakeError::Protocol` and sent to the peer in `Sorry`.
const INITIAL_CHANNEL_CREDIT_ZERO_ERROR: &str = "initial_channel_credit must be greater than zero";
7
/// Ways the session handshake can fail.
#[derive(Debug)]
pub enum HandshakeError {
    /// Sending or receiving a frame on the link failed.
    Io(std::io::Error),
    /// An outgoing handshake message could not be CBOR-encoded.
    Encode(String),
    /// An incoming frame could not be decoded as a handshake message.
    Decode(String),
    /// The link yielded no more frames while a handshake message was expected.
    PeerClosed,
    /// The exchange violated the handshake protocol; the string says how.
    Protocol(String),
    /// The peer rejected the handshake with a `Sorry` message; carries its reason.
    Sorry(String),
    /// The session cannot be resumed.
    NotResumable,
}

impl std::fmt::Display for HandshakeError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without a payload are plain literals.
            Self::PeerClosed => f.write_str("peer closed during handshake"),
            Self::NotResumable => f.write_str("session is not resumable"),
            // Variants with a payload append it after a fixed prefix.
            Self::Io(e) => write!(f, "handshake io error: {e}"),
            Self::Encode(e) => write!(f, "handshake encode error: {e}"),
            Self::Decode(e) => write!(f, "handshake decode error: {e}"),
            Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
            Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
        }
    }
}

// The default `source()` (None) is kept: inner errors are already part of the
// `Display` text above.
impl std::error::Error for HandshakeError {}
34
35// r[impl rpc.flow-control.credit.initial.zero]
36fn validate_initial_channel_credit(settings: &ConnectionSettings) -> Result<(), HandshakeError> {
37    if settings.initial_channel_credit == 0 {
38        return Err(HandshakeError::Protocol(
39            INITIAL_CHANNEL_CREDIT_ZERO_ERROR.into(),
40        ));
41    }
42    Ok(())
43}
44
45/// Extract the Message schema from the static shape.
46fn message_schema() -> Vec<Schema> {
47    vox_types::extract_schemas(<vox_types::Message<'static> as facet::Facet<'static>>::SHAPE)
48        .expect("schema extraction")
49        .schemas
50        .clone()
51}
52
53/// Send a CBOR-encoded handshake message on a raw link.
54async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
55    let bytes = facet_cbor::to_vec(msg).map_err(|e| HandshakeError::Encode(e.to_string()))?;
56    vox_types::dlog!(
57        "[handshake] send {:?} ({} bytes)",
58        handshake_tag(msg),
59        bytes.len()
60    );
61    tx.send(bytes).await.map_err(HandshakeError::Io)
62}
63
64/// Receive and decode a CBOR handshake message from a raw link.
65async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
66    let backing = rx
67        .recv()
68        .await
69        .map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
70        .ok_or(HandshakeError::PeerClosed)?;
71    vox_types::dlog!(
72        "[handshake] recv raw frame ({} bytes)",
73        backing.as_bytes().len()
74    );
75    let msg = facet_cbor::from_slice(backing.as_bytes())
76        .map_err(|e| HandshakeError::Decode(e.to_string()))?;
77    vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
78    Ok(msg)
79}
80
81fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
82    match msg {
83        HandshakeMessage::Hello(_) => "Hello",
84        HandshakeMessage::HelloYourself(_) => "HelloYourself",
85        HandshakeMessage::LetsGo(_) => "LetsGo",
86        HandshakeMessage::Sorry(_) => "Sorry",
87    }
88}
89
90// r[impl session.handshake]
91// r[impl session.handshake.cbor]
92/// Perform the CBOR handshake as the initiator.
93///
94/// Three-step exchange:
95/// 1. Send Hello
96/// 2. Receive HelloYourself (or Sorry)
97/// 3. Send LetsGo (or Sorry)
98pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
99    tx: &Tx,
100    rx: &mut Rx,
101    settings: ConnectionSettings,
102    supports_retry: bool,
103    resume_key: Option<&SessionResumeKey>,
104    metadata: vox_types::Metadata<'static>,
105) -> Result<HandshakeResult, HandshakeError> {
106    validate_initial_channel_credit(&settings)?;
107
108    let our_schema = message_schema();
109
110    let hello = vox_types::Hello {
111        parity: settings.parity,
112        connection_settings: settings.clone(),
113        message_payload_schema: our_schema.clone(),
114        supports_retry,
115        resume_key: resume_key.map(ResumeKeyBytes::from_key),
116        metadata,
117    };
118
119    // Step 1: Send Hello
120    send_handshake(tx, &HandshakeMessage::Hello(hello)).await?;
121
122    // Step 2: Receive HelloYourself or Sorry
123    let response = recv_handshake(rx).await?;
124    let hy = match response {
125        HandshakeMessage::HelloYourself(hy) => hy,
126        HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
127        _ => {
128            return Err(HandshakeError::Protocol(
129                "expected HelloYourself or Sorry".into(),
130            ));
131        }
132    };
133    if hy.connection_settings.initial_channel_credit == 0 {
134        let reason = INITIAL_CHANNEL_CREDIT_ZERO_ERROR.to_string();
135        send_handshake(
136            tx,
137            &HandshakeMessage::Sorry(vox_types::Sorry {
138                reason: reason.clone(),
139            }),
140        )
141        .await?;
142        return Err(HandshakeError::Protocol(reason));
143    }
144
145    // Step 3: Send LetsGo
146    // TODO: Compare schemas and send Sorry if incompatible
147    send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;
148
149    let session_resume_key = hy.resume_key.as_ref().and_then(|k| k.to_key());
150
151    Ok(HandshakeResult {
152        role: SessionRole::Initiator,
153        our_settings: settings,
154        peer_settings: hy.connection_settings,
155        peer_supports_retry: hy.supports_retry,
156        session_resume_key,
157        peer_resume_key: None, // initiator doesn't receive a peer resume key
158        our_schema,
159        peer_schema: hy.message_payload_schema,
160        peer_metadata: hy.metadata,
161    })
162}
163
164// r[impl session.handshake]
165// r[impl session.handshake.cbor]
166/// Perform the CBOR handshake as the acceptor.
167///
168/// Three-step exchange:
169/// 1. Receive Hello
170/// 2. Send HelloYourself (or Sorry)
171/// 3. Receive LetsGo (or Sorry)
172pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
173    tx: &Tx,
174    rx: &mut Rx,
175    settings: ConnectionSettings,
176    supports_retry: bool,
177    resumable: bool,
178    expected_resume_key: Option<&SessionResumeKey>,
179    metadata: vox_types::Metadata<'static>,
180) -> Result<HandshakeResult, HandshakeError> {
181    validate_initial_channel_credit(&settings)?;
182
183    // Step 1: Receive Hello
184    let hello = match recv_handshake(rx).await? {
185        HandshakeMessage::Hello(h) => h,
186        _ => return Err(HandshakeError::Protocol("expected Hello".into())),
187    };
188    if hello.connection_settings.initial_channel_credit == 0 {
189        let reason = INITIAL_CHANNEL_CREDIT_ZERO_ERROR.to_string();
190        send_handshake(
191            tx,
192            &HandshakeMessage::Sorry(vox_types::Sorry {
193                reason: reason.clone(),
194            }),
195        )
196        .await?;
197        return Err(HandshakeError::Protocol(reason));
198    }
199
200    // Validate resume key if this is a resumption attempt
201    if let Some(expected) = expected_resume_key {
202        let actual = hello.resume_key.as_ref().and_then(|k| k.to_key());
203        match actual {
204            Some(actual) if actual == *expected => {} // OK
205            _ => {
206                let reason = "session resume key mismatch".to_string();
207                send_handshake(
208                    tx,
209                    &HandshakeMessage::Sorry(vox_types::Sorry {
210                        reason: reason.clone(),
211                    }),
212                )
213                .await?;
214                return Err(HandshakeError::Protocol(reason));
215            }
216        }
217    }
218
219    // Acceptor adopts opposite parity
220    let our_settings = ConnectionSettings {
221        parity: hello.parity.other(),
222        ..settings
223    };
224
225    // Generate resume key if we're resumable
226    let our_resume_key = if resumable {
227        Some(fresh_resume_key()?)
228    } else {
229        None
230    };
231
232    let our_schema = message_schema();
233
234    // Step 2: Send HelloYourself
235    let hy = vox_types::HelloYourself {
236        connection_settings: our_settings.clone(),
237        message_payload_schema: our_schema.clone(),
238        supports_retry,
239        resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
240        metadata,
241    };
242    send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;
243
244    // Step 3: Receive LetsGo or Sorry
245    let response = recv_handshake(rx).await?;
246    match response {
247        HandshakeMessage::LetsGo(_) => {}
248        HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
249        _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
250    }
251
252    let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
253
254    Ok(HandshakeResult {
255        role: SessionRole::Acceptor,
256        our_settings,
257        peer_settings: hello.connection_settings,
258        peer_supports_retry: hello.supports_retry,
259        session_resume_key: our_resume_key,
260        peer_resume_key,
261        our_schema,
262        peer_schema: hello.message_payload_schema,
263        peer_metadata: hello.metadata,
264    })
265}
266
267fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
268    let mut bytes = [0u8; 16];
269    getrandom::fill(&mut bytes).map_err(|error| {
270        HandshakeError::Protocol(format!("failed to generate session key: {error}"))
271    })?;
272    Ok(SessionResumeKey(bytes))
273}
274
#[cfg(test)]
mod tests {
    use vox_types::{Link, Parity};

    use super::*;

    /// Connection settings with a fixed request limit of 64 and the given
    /// parity and initial channel credit.
    fn settings(parity: Parity, initial_channel_credit: u32) -> ConnectionSettings {
        ConnectionSettings {
            parity,
            max_concurrent_requests: 64,
            initial_channel_credit,
        }
    }

    /// Assert that `result` is the zero-credit protocol error.
    fn assert_zero_credit_protocol_error(result: &Result<HandshakeResult, HandshakeError>) {
        let is_zero_credit_error = matches!(
            result,
            Err(HandshakeError::Protocol(message))
                if message == INITIAL_CHANNEL_CREDIT_ZERO_ERROR
        );
        assert!(
            is_zero_credit_error,
            "expected zero-credit protocol error, got: {result:?}"
        );
    }

    // r[verify rpc.flow-control.credit.initial.zero]
    #[tokio::test]
    async fn initiator_rejects_local_zero_initial_credit_before_handshake() {
        // The peer half is kept alive but never driven: the initiator must
        // fail before touching the link.
        let (link, _peer) = crate::memory_link_pair(1);
        let (tx, mut rx) = link.split();

        let zero_credit = settings(Parity::Odd, 0);
        let result = handshake_as_initiator(&tx, &mut rx, zero_credit, true, None, vec![]).await;

        assert_zero_credit_protocol_error(&result);
    }

    // r[verify rpc.flow-control.credit.initial.zero]
    #[tokio::test]
    async fn acceptor_rejects_peer_zero_initial_credit_before_session_starts() {
        let (client_link, server_link) = crate::memory_link_pair(4);
        let (client_tx, mut client_rx) = client_link.split();
        let (server_tx, mut server_rx) = server_link.split();

        // Run a well-configured acceptor in the background.
        let acceptor = tokio::spawn(async move {
            handshake_as_acceptor(
                &server_tx,
                &mut server_rx,
                settings(Parity::Even, 16),
                true,
                false,
                None,
                vec![],
            )
            .await
        });

        // Play the initiator by hand, advertising zero initial credit.
        let bad_hello = HandshakeMessage::Hello(vox_types::Hello {
            parity: Parity::Odd,
            connection_settings: settings(Parity::Odd, 0),
            message_payload_schema: message_schema(),
            supports_retry: true,
            resume_key: None,
            metadata: vec![],
        });
        send_handshake(&client_tx, &bad_hello)
            .await
            .expect("send hello");

        // The acceptor must answer with Sorry carrying the zero-credit reason.
        let response = recv_handshake(&mut client_rx).await.expect("recv sorry");
        let is_zero_credit_sorry = matches!(
            response,
            HandshakeMessage::Sorry(vox_types::Sorry { ref reason })
                if reason == INITIAL_CHANNEL_CREDIT_ZERO_ERROR
        );
        assert!(
            is_zero_credit_sorry,
            "expected Sorry for zero credit, got: {response:?}"
        );

        // ...and must itself fail with the matching protocol error.
        let result = acceptor.await.expect("acceptor task");
        assert_zero_credit_protocol_error(&result);
    }
}
363}