1use vox_types::{
2 ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, LinkTxPermit,
3 ResumeKeyBytes, Schema, SessionResumeKey, SessionRole, WriteSlot,
4};
5
/// Reasons a connection handshake can fail before a `HandshakeResult` is produced.
#[derive(Debug)]
pub enum HandshakeError {
    /// Transport failure while reserving/allocating a send slot or receiving a frame.
    ///
    /// Receive-side link errors are re-wrapped with `std::io::Error::other`, so only their
    /// message text (not their original type) survives.
    Io(std::io::Error),
    /// An outgoing handshake message could not be CBOR-encoded; carries the encoder message.
    Encode(String),
    /// An incoming frame could not be CBOR-decoded into a handshake message.
    Decode(String),
    /// The link reported end-of-stream before the handshake finished.
    PeerClosed,
    /// The handshake sequence was violated (unexpected message, resume-key mismatch) or a
    /// local step such as resume-key generation failed.
    Protocol(String),
    /// The peer explicitly rejected the handshake with `Sorry`; carries the peer's reason.
    Sorry(String),
    /// The session is not resumable.
    // NOTE(review): not constructed by the handshake functions visible here; confirm
    // where callers produce it.
    NotResumable,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "handshake io error: {e}"),
            Self::Encode(e) => write!(f, "handshake encode error: {e}"),
            Self::Decode(e) => write!(f, "handshake decode error: {e}"),
            Self::PeerClosed => write!(f, "peer closed during handshake"),
            Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
            Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
            Self::NotResumable => write!(f, "session is not resumable"),
        }
    }
}

impl std::error::Error for HandshakeError {
    /// Exposes the wrapped `std::io::Error` for the `Io` variant so callers can walk the
    /// error chain or downcast to inspect its `ErrorKind`; all other variants are leaves.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            _ => None,
        }
    }
}
32
33fn message_schema() -> Vec<Schema> {
35 vox_types::extract_schemas(<vox_types::Message<'static> as facet::Facet<'static>>::SHAPE)
36 .expect("schema extraction")
37 .schemas
38}
39
40async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
42 let bytes = facet_cbor::to_vec(msg).map_err(|e| HandshakeError::Encode(e.to_string()))?;
43 vox_types::dlog!(
44 "[handshake] send {:?} ({} bytes)",
45 handshake_tag(msg),
46 bytes.len()
47 );
48 let permit = tx.reserve().await.map_err(HandshakeError::Io)?;
49 let mut slot = permit.alloc(bytes.len()).map_err(HandshakeError::Io)?;
50 slot.as_mut_slice().copy_from_slice(&bytes);
51 slot.commit();
52 Ok(())
53}
54
55async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
57 let backing = rx
58 .recv()
59 .await
60 .map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
61 .ok_or(HandshakeError::PeerClosed)?;
62 vox_types::dlog!(
63 "[handshake] recv raw frame ({} bytes)",
64 backing.as_bytes().len()
65 );
66 let msg = facet_cbor::from_slice(backing.as_bytes())
67 .map_err(|e| HandshakeError::Decode(e.to_string()))?;
68 vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
69 Ok(msg)
70}
71
72fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
73 match msg {
74 HandshakeMessage::Hello(_) => "Hello",
75 HandshakeMessage::HelloYourself(_) => "HelloYourself",
76 HandshakeMessage::LetsGo(_) => "LetsGo",
77 HandshakeMessage::Sorry(_) => "Sorry",
78 }
79}
80
81pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
90 tx: &Tx,
91 rx: &mut Rx,
92 settings: ConnectionSettings,
93 supports_retry: bool,
94 resume_key: Option<&SessionResumeKey>,
95) -> Result<HandshakeResult, HandshakeError> {
96 let our_schema = message_schema();
97
98 let hello = vox_types::Hello {
99 parity: settings.parity,
100 connection_settings: settings.clone(),
101 message_payload_schema: our_schema.clone(),
102 supports_retry,
103 resume_key: resume_key.map(ResumeKeyBytes::from_key),
104 };
105
106 send_handshake(tx, &HandshakeMessage::Hello(hello)).await?;
108
109 let response = recv_handshake(rx).await?;
111 let hy = match response {
112 HandshakeMessage::HelloYourself(hy) => hy,
113 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
114 _ => {
115 return Err(HandshakeError::Protocol(
116 "expected HelloYourself or Sorry".into(),
117 ));
118 }
119 };
120
121 send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;
124
125 let session_resume_key = hy.resume_key.as_ref().and_then(|k| k.to_key());
126
127 Ok(HandshakeResult {
128 role: SessionRole::Initiator,
129 our_settings: settings,
130 peer_settings: hy.connection_settings,
131 peer_supports_retry: hy.supports_retry,
132 session_resume_key,
133 peer_resume_key: None, our_schema,
135 peer_schema: hy.message_payload_schema,
136 })
137}
138
139pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
148 tx: &Tx,
149 rx: &mut Rx,
150 settings: ConnectionSettings,
151 supports_retry: bool,
152 resumable: bool,
153 expected_resume_key: Option<&SessionResumeKey>,
154) -> Result<HandshakeResult, HandshakeError> {
155 let hello = match recv_handshake(rx).await? {
157 HandshakeMessage::Hello(h) => h,
158 _ => return Err(HandshakeError::Protocol("expected Hello".into())),
159 };
160
161 if let Some(expected) = expected_resume_key {
163 let actual = hello.resume_key.as_ref().and_then(|k| k.to_key());
164 match actual {
165 Some(actual) if actual == *expected => {} _ => {
167 let reason = "session resume key mismatch".to_string();
168 send_handshake(
169 tx,
170 &HandshakeMessage::Sorry(vox_types::Sorry {
171 reason: reason.clone(),
172 }),
173 )
174 .await?;
175 return Err(HandshakeError::Protocol(reason));
176 }
177 }
178 }
179
180 let our_settings = ConnectionSettings {
182 parity: hello.parity.other(),
183 ..settings
184 };
185
186 let our_resume_key = if resumable {
188 Some(fresh_resume_key()?)
189 } else {
190 None
191 };
192
193 let our_schema = message_schema();
194
195 let hy = vox_types::HelloYourself {
197 connection_settings: our_settings.clone(),
198 message_payload_schema: our_schema.clone(),
199 supports_retry,
200 resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
201 };
202 send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;
203
204 let response = recv_handshake(rx).await?;
206 match response {
207 HandshakeMessage::LetsGo(_) => {}
208 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
209 _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
210 }
211
212 let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
213
214 Ok(HandshakeResult {
215 role: SessionRole::Acceptor,
216 our_settings,
217 peer_settings: hello.connection_settings,
218 peer_supports_retry: hello.supports_retry,
219 session_resume_key: our_resume_key,
220 peer_resume_key,
221 our_schema,
222 peer_schema: hello.message_payload_schema,
223 })
224}
225
226fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
227 let mut bytes = [0u8; 16];
228 getrandom::fill(&mut bytes).map_err(|error| {
229 HandshakeError::Protocol(format!("failed to generate session key: {error}"))
230 })?;
231 Ok(SessionResumeKey(bytes))
232}