1use vox_types::{
2 ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, ResumeKeyBytes, Schema,
3 SessionResumeKey, SessionRole,
4};
5
/// Errors that can end a session handshake.
#[derive(Debug)]
pub enum HandshakeError {
    /// Transport failure while sending or receiving a handshake frame.
    ///
    /// On the receive path the link's error is re-wrapped via
    /// `std::io::Error::other` from its string form.
    Io(std::io::Error),
    /// A handshake message could not be CBOR-encoded; holds the encoder's message.
    Encode(String),
    /// A received frame could not be CBOR-decoded; holds the decoder's message.
    Decode(String),
    /// The link reported end-of-stream before the handshake finished.
    PeerClosed,
    /// The peer violated the handshake protocol (unexpected message, resume key
    /// mismatch), or a local step such as resume-key generation failed.
    Protocol(String),
    /// The peer rejected the handshake with a `Sorry`; holds the peer's reason.
    Sorry(String),
    /// The session cannot be resumed.
    ///
    /// NOTE(review): not constructed by the handshake functions shown here;
    /// presumably produced by callers — confirm.
    NotResumable,
}
16
17impl std::fmt::Display for HandshakeError {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Self::Io(e) => write!(f, "handshake io error: {e}"),
21 Self::Encode(e) => write!(f, "handshake encode error: {e}"),
22 Self::Decode(e) => write!(f, "handshake decode error: {e}"),
23 Self::PeerClosed => write!(f, "peer closed during handshake"),
24 Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
25 Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
26 Self::NotResumable => write!(f, "session is not resumable"),
27 }
28 }
29}
30
31impl std::error::Error for HandshakeError {}
32
33fn message_schema() -> Vec<Schema> {
35 vox_types::extract_schemas(<vox_types::Message<'static> as facet::Facet<'static>>::SHAPE)
36 .expect("schema extraction")
37 .schemas
38 .clone()
39}
40
41async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
43 let bytes = facet_cbor::to_vec(msg).map_err(|e| HandshakeError::Encode(e.to_string()))?;
44 vox_types::dlog!(
45 "[handshake] send {:?} ({} bytes)",
46 handshake_tag(msg),
47 bytes.len()
48 );
49 tx.send(bytes).await.map_err(HandshakeError::Io)
50}
51
52async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
54 let backing = rx
55 .recv()
56 .await
57 .map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
58 .ok_or(HandshakeError::PeerClosed)?;
59 vox_types::dlog!(
60 "[handshake] recv raw frame ({} bytes)",
61 backing.as_bytes().len()
62 );
63 let msg = facet_cbor::from_slice(backing.as_bytes())
64 .map_err(|e| HandshakeError::Decode(e.to_string()))?;
65 vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
66 Ok(msg)
67}
68
69fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
70 match msg {
71 HandshakeMessage::Hello(_) => "Hello",
72 HandshakeMessage::HelloYourself(_) => "HelloYourself",
73 HandshakeMessage::LetsGo(_) => "LetsGo",
74 HandshakeMessage::Sorry(_) => "Sorry",
75 }
76}
77
78pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
87 tx: &Tx,
88 rx: &mut Rx,
89 settings: ConnectionSettings,
90 supports_retry: bool,
91 resume_key: Option<&SessionResumeKey>,
92 metadata: vox_types::Metadata<'static>,
93) -> Result<HandshakeResult, HandshakeError> {
94 let our_schema = message_schema();
95
96 let hello = vox_types::Hello {
97 parity: settings.parity,
98 connection_settings: settings.clone(),
99 message_payload_schema: our_schema.clone(),
100 supports_retry,
101 resume_key: resume_key.map(ResumeKeyBytes::from_key),
102 metadata,
103 };
104
105 send_handshake(tx, &HandshakeMessage::Hello(hello)).await?;
107
108 let response = recv_handshake(rx).await?;
110 let hy = match response {
111 HandshakeMessage::HelloYourself(hy) => hy,
112 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
113 _ => {
114 return Err(HandshakeError::Protocol(
115 "expected HelloYourself or Sorry".into(),
116 ));
117 }
118 };
119
120 send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;
123
124 let session_resume_key = hy.resume_key.as_ref().and_then(|k| k.to_key());
125
126 Ok(HandshakeResult {
127 role: SessionRole::Initiator,
128 our_settings: settings,
129 peer_settings: hy.connection_settings,
130 peer_supports_retry: hy.supports_retry,
131 session_resume_key,
132 peer_resume_key: None, our_schema,
134 peer_schema: hy.message_payload_schema,
135 peer_metadata: hy.metadata,
136 })
137}
138
139pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
148 tx: &Tx,
149 rx: &mut Rx,
150 settings: ConnectionSettings,
151 supports_retry: bool,
152 resumable: bool,
153 expected_resume_key: Option<&SessionResumeKey>,
154 metadata: vox_types::Metadata<'static>,
155) -> Result<HandshakeResult, HandshakeError> {
156 let hello = match recv_handshake(rx).await? {
158 HandshakeMessage::Hello(h) => h,
159 _ => return Err(HandshakeError::Protocol("expected Hello".into())),
160 };
161
162 if let Some(expected) = expected_resume_key {
164 let actual = hello.resume_key.as_ref().and_then(|k| k.to_key());
165 match actual {
166 Some(actual) if actual == *expected => {} _ => {
168 let reason = "session resume key mismatch".to_string();
169 send_handshake(
170 tx,
171 &HandshakeMessage::Sorry(vox_types::Sorry {
172 reason: reason.clone(),
173 }),
174 )
175 .await?;
176 return Err(HandshakeError::Protocol(reason));
177 }
178 }
179 }
180
181 let our_settings = ConnectionSettings {
183 parity: hello.parity.other(),
184 ..settings
185 };
186
187 let our_resume_key = if resumable {
189 Some(fresh_resume_key()?)
190 } else {
191 None
192 };
193
194 let our_schema = message_schema();
195
196 let hy = vox_types::HelloYourself {
198 connection_settings: our_settings.clone(),
199 message_payload_schema: our_schema.clone(),
200 supports_retry,
201 resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
202 metadata,
203 };
204 send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;
205
206 let response = recv_handshake(rx).await?;
208 match response {
209 HandshakeMessage::LetsGo(_) => {}
210 HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
211 _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
212 }
213
214 let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
215
216 Ok(HandshakeResult {
217 role: SessionRole::Acceptor,
218 our_settings,
219 peer_settings: hello.connection_settings,
220 peer_supports_retry: hello.supports_retry,
221 session_resume_key: our_resume_key,
222 peer_resume_key,
223 our_schema,
224 peer_schema: hello.message_payload_schema,
225 peer_metadata: hello.metadata,
226 })
227}
228
229fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
230 let mut bytes = [0u8; 16];
231 getrandom::fill(&mut bytes).map_err(|error| {
232 HandshakeError::Protocol(format!("failed to generate session key: {error}"))
233 })?;
234 Ok(SessionResumeKey(bytes))
235}