1use vox_types::{
2 ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, ResumeKeyBytes, Schema,
3 SessionResumeKey, SessionRole,
4};
5
/// Reason text reported when a side advertises `initial_channel_credit == 0`.
///
/// Used both as the payload of the local `HandshakeError::Protocol` and as the
/// `reason` of the `Sorry` message sent to a peer that advertised zero credit.
const INITIAL_CHANNEL_CREDIT_ZERO_ERROR: &str = "initial_channel_credit must be greater than zero";
7
/// Everything that can make a handshake fail.
#[derive(Debug)]
pub enum HandshakeError {
    /// The link failed while sending or receiving a handshake frame.
    Io(std::io::Error),
    /// A handshake message could not be encoded; carries the encoder's message.
    Encode(String),
    /// A received frame could not be decoded as a handshake message; carries
    /// the decoder's message.
    Decode(String),
    /// The link yielded no frame where a handshake message was expected.
    PeerClosed,
    /// The exchange was invalid: an unexpected message, or unacceptable settings.
    Protocol(String),
    /// The peer answered with `Sorry`; carries the peer's stated reason.
    Sorry(String),
    /// The session is not resumable. Not constructed by the handshake
    /// functions in this file.
    NotResumable,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without a payload are plain constant strings.
            Self::PeerClosed => f.write_str("peer closed during handshake"),
            Self::NotResumable => f.write_str("session is not resumable"),
            // Variants with a payload append it after a fixed prefix.
            Self::Io(e) => write!(f, "handshake io error: {e}"),
            Self::Encode(e) => write!(f, "handshake encode error: {e}"),
            Self::Decode(e) => write!(f, "handshake decode error: {e}"),
            Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
            Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
        }
    }
}

// The underlying cause is already rendered by `Display`, so no `source` is exposed.
impl std::error::Error for HandshakeError {}
34
35fn validate_initial_channel_credit(settings: &ConnectionSettings) -> Result<(), HandshakeError> {
37 if settings.initial_channel_credit == 0 {
38 return Err(HandshakeError::Protocol(
39 INITIAL_CHANNEL_CREDIT_ZERO_ERROR.into(),
40 ));
41 }
42 Ok(())
43}
44
45fn message_schema() -> Vec<Schema> {
47 vox_types::extract_schemas(<vox_types::Message<'static> as facet::Facet<'static>>::SHAPE)
48 .expect("schema extraction")
49 .schemas
50 .clone()
51}
52
53async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
55 let bytes = facet_cbor::to_vec(msg).map_err(|e| HandshakeError::Encode(e.to_string()))?;
56 vox_types::dlog!(
57 "[handshake] send {:?} ({} bytes)",
58 handshake_tag(msg),
59 bytes.len()
60 );
61 tx.send(bytes).await.map_err(HandshakeError::Io)
62}
63
64async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
66 let backing = rx
67 .recv()
68 .await
69 .map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
70 .ok_or(HandshakeError::PeerClosed)?;
71 vox_types::dlog!(
72 "[handshake] recv raw frame ({} bytes)",
73 backing.as_bytes().len()
74 );
75 let msg = facet_cbor::from_slice(backing.as_bytes())
76 .map_err(|e| HandshakeError::Decode(e.to_string()))?;
77 vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
78 Ok(msg)
79}
80
81fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
82 match msg {
83 HandshakeMessage::Hello(_) => "Hello",
84 HandshakeMessage::HelloYourself(_) => "HelloYourself",
85 HandshakeMessage::LetsGo(_) => "LetsGo",
86 HandshakeMessage::Sorry(_) => "Sorry",
87 }
88}
89
90pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
99 tx: &Tx,
100 rx: &mut Rx,
101 settings: ConnectionSettings,
102 supports_retry: bool,
103 resume_key: Option<&SessionResumeKey>,
104 metadata: vox_types::Metadata<'static>,
105) -> Result<HandshakeResult, HandshakeError> {
106 validate_initial_channel_credit(&settings)?;
107
108 let our_schema = message_schema();
109
110 let hello = vox_types::Hello {
111 parity: settings.parity,
112 connection_settings: settings.clone(),
113 message_payload_schema: our_schema.clone(),
114 supports_retry,
115 resume_key: resume_key.map(ResumeKeyBytes::from_key),
116 metadata,
117 };
118
119 send_handshake(tx, &HandshakeMessage::Hello(hello)).await?;
121
122 let response = recv_handshake(rx).await?;
124 let hy = match response {
125 HandshakeMessage::HelloYourself(hy) => hy,
126 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
127 _ => {
128 return Err(HandshakeError::Protocol(
129 "expected HelloYourself or Sorry".into(),
130 ));
131 }
132 };
133 if hy.connection_settings.initial_channel_credit == 0 {
134 let reason = INITIAL_CHANNEL_CREDIT_ZERO_ERROR.to_string();
135 send_handshake(
136 tx,
137 &HandshakeMessage::Sorry(vox_types::Sorry {
138 reason: reason.clone(),
139 }),
140 )
141 .await?;
142 return Err(HandshakeError::Protocol(reason));
143 }
144
145 send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;
148
149 let session_resume_key = hy.resume_key.as_ref().and_then(|k| k.to_key());
150
151 Ok(HandshakeResult {
152 role: SessionRole::Initiator,
153 our_settings: settings,
154 peer_settings: hy.connection_settings,
155 peer_supports_retry: hy.supports_retry,
156 session_resume_key,
157 peer_resume_key: None, our_schema,
159 peer_schema: hy.message_payload_schema,
160 peer_metadata: hy.metadata,
161 })
162}
163
164pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
173 tx: &Tx,
174 rx: &mut Rx,
175 settings: ConnectionSettings,
176 supports_retry: bool,
177 resumable: bool,
178 expected_resume_key: Option<&SessionResumeKey>,
179 metadata: vox_types::Metadata<'static>,
180) -> Result<HandshakeResult, HandshakeError> {
181 validate_initial_channel_credit(&settings)?;
182
183 let hello = match recv_handshake(rx).await? {
185 HandshakeMessage::Hello(h) => h,
186 _ => return Err(HandshakeError::Protocol("expected Hello".into())),
187 };
188 if hello.connection_settings.initial_channel_credit == 0 {
189 let reason = INITIAL_CHANNEL_CREDIT_ZERO_ERROR.to_string();
190 send_handshake(
191 tx,
192 &HandshakeMessage::Sorry(vox_types::Sorry {
193 reason: reason.clone(),
194 }),
195 )
196 .await?;
197 return Err(HandshakeError::Protocol(reason));
198 }
199
200 if let Some(expected) = expected_resume_key {
202 let actual = hello.resume_key.as_ref().and_then(|k| k.to_key());
203 match actual {
204 Some(actual) if actual == *expected => {} _ => {
206 let reason = "session resume key mismatch".to_string();
207 send_handshake(
208 tx,
209 &HandshakeMessage::Sorry(vox_types::Sorry {
210 reason: reason.clone(),
211 }),
212 )
213 .await?;
214 return Err(HandshakeError::Protocol(reason));
215 }
216 }
217 }
218
219 let our_settings = ConnectionSettings {
221 parity: hello.parity.other(),
222 ..settings
223 };
224
225 let our_resume_key = if resumable {
227 Some(fresh_resume_key()?)
228 } else {
229 None
230 };
231
232 let our_schema = message_schema();
233
234 let hy = vox_types::HelloYourself {
236 connection_settings: our_settings.clone(),
237 message_payload_schema: our_schema.clone(),
238 supports_retry,
239 resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
240 metadata,
241 };
242 send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;
243
244 let response = recv_handshake(rx).await?;
246 match response {
247 HandshakeMessage::LetsGo(_) => {}
248 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
249 _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
250 }
251
252 let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
253
254 Ok(HandshakeResult {
255 role: SessionRole::Acceptor,
256 our_settings,
257 peer_settings: hello.connection_settings,
258 peer_supports_retry: hello.supports_retry,
259 session_resume_key: our_resume_key,
260 peer_resume_key,
261 our_schema,
262 peer_schema: hello.message_payload_schema,
263 peer_metadata: hello.metadata,
264 })
265}
266
267fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
268 let mut bytes = [0u8; 16];
269 getrandom::fill(&mut bytes).map_err(|error| {
270 HandshakeError::Protocol(format!("failed to generate session key: {error}"))
271 })?;
272 Ok(SessionResumeKey(bytes))
273}
274
#[cfg(test)]
mod tests {
    use vox_types::{Link, Parity};

    use super::*;

    /// Builds connection settings for a test peer; only parity and credit vary.
    fn settings(parity: Parity, initial_channel_credit: u32) -> ConnectionSettings {
        ConnectionSettings {
            parity,
            max_concurrent_requests: 64,
            initial_channel_credit,
        }
    }

    /// Asserts that a handshake failed with the zero-credit protocol error.
    #[track_caller]
    fn assert_zero_credit_protocol_error(result: &Result<HandshakeResult, HandshakeError>) {
        assert!(
            matches!(
                result,
                Err(HandshakeError::Protocol(message))
                    if message == INITIAL_CHANNEL_CREDIT_ZERO_ERROR
            ),
            "expected zero-credit protocol error, got: {result:?}"
        );
    }

    /// A local zero-credit configuration must fail on the initiator side.
    #[tokio::test]
    async fn initiator_rejects_local_zero_initial_credit_before_handshake() {
        let (link, _peer) = crate::memory_link_pair(1);
        let (tx, mut rx) = link.split();
        let zero_credit = settings(Parity::Odd, 0);

        let outcome = handshake_as_initiator(&tx, &mut rx, zero_credit, true, None, vec![]).await;

        assert_zero_credit_protocol_error(&outcome);
    }

    /// A peer `Hello` with zero credit must be answered with `Sorry`, and the
    /// acceptor must fail with the matching protocol error.
    #[tokio::test]
    async fn acceptor_rejects_peer_zero_initial_credit_before_session_starts() {
        let (client_link, server_link) = crate::memory_link_pair(4);
        let (client_tx, mut client_rx) = client_link.split();
        let (server_tx, mut server_rx) = server_link.split();

        let acceptor = tokio::spawn(async move {
            let local = settings(Parity::Even, 16);
            handshake_as_acceptor(&server_tx, &mut server_rx, local, true, false, None, vec![])
                .await
        });

        // Play the initiator by hand so an invalid Hello can be sent.
        let zero_credit_hello = HandshakeMessage::Hello(vox_types::Hello {
            parity: Parity::Odd,
            connection_settings: settings(Parity::Odd, 0),
            message_payload_schema: message_schema(),
            supports_retry: true,
            resume_key: None,
            metadata: vec![],
        });
        send_handshake(&client_tx, &zero_credit_hello)
            .await
            .expect("send hello");

        let response = recv_handshake(&mut client_rx).await.expect("recv sorry");
        match &response {
            HandshakeMessage::Sorry(sorry)
                if sorry.reason == INITIAL_CHANNEL_CREDIT_ZERO_ERROR => {}
            _ => panic!("expected Sorry for zero credit, got: {response:?}"),
        }

        let result = acceptor.await.expect("acceptor task");
        assert_zero_credit_protocol_error(&result);
    }
}