1use super::{DeviceEnrollment, Error, Result, ServerPairUrl};
3use crate::NetworkAccount;
4use futures::{
5 stream::{SplitSink, SplitStream},
6 SinkExt, StreamExt,
7};
8use prost::bytes::Bytes;
9use snow::{Builder, HandshakeState, Keypair, TransportState};
10use sos_account::Account;
11use sos_backend::BackendTarget;
12use sos_core::{
13 device::{DeviceMetaData, DevicePublicKey, TrustedDevice},
14 events::DeviceEvent,
15 AccountId, Origin,
16};
17use sos_protocol::{
18 network_client::WebSocketRequest,
19 pairing_message,
20 tokio_tungstenite::{
21 connect_async,
22 tungstenite::{
23 protocol::{frame::coding::CloseCode, CloseFrame, Message},
24 Utf8Bytes,
25 },
26 MaybeTlsStream, WebSocketStream,
27 },
28 AccountSync, PairingConfirm, PairingMessage, PairingReady,
29 PairingRequest, ProtoMessage, RelayHeader, RelayPacket, RelayPayload,
30 SyncOptions,
31};
32use std::collections::HashSet;
33use tokio::{net::TcpStream, sync::mpsc};
34use url::Url;
35
/// Noise protocol name: `XX` handshake with a pre-shared key mixed in at
/// position 3 (`psk3`), Curve25519 DH, ChaCha20-Poly1305 cipher and BLAKE2s
/// hash.
const PATTERN: &str = "Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s";
/// Path of the relay websocket endpoint on the pairing server.
const RELAY_PATH: &str = "api/v1/relay";
/// Size in bytes of the authentication tag the cipher appends to every
/// transport message; used to size encryption output buffers.
const TAGLEN: usize = 16;
41
/// Noise state of a pairing connection.
enum Tunnel {
    /// The Noise handshake is still in progress.
    Handshake(HandshakeState),
    /// The handshake has completed; messages are encrypted with the
    /// transport keys.
    Transport(TransportState),
}
49
/// Write half of the relay websocket connection.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the relay websocket connection.
type WsStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
52
/// Progress of the pairing protocol for one side of the connection.
#[derive(Debug)]
enum PairProtocolState {
    /// No handshake message has been sent or processed yet.
    Pending,
    /// The first Noise handshake message (`e`) has been sent or answered.
    Handshake,
    /// The Noise handshake has finished and the tunnel is in transport
    /// mode; pairing messages are exchanged encrypted.
    PskHandshake,
    /// The pairing exchange has completed.
    Done,
}
65
/// What to do in response to an incoming relay packet.
#[derive(Debug)]
enum IncomingAction {
    /// Move to the given state and send the reply packet.
    Reply(PairProtocolState, RelayPacket),
    /// Handle a pairing message decrypted from the transport tunnel.
    HandleMessage(PairingMessage),
}
71
72async fn listen(
74 mut rx: WsStream,
75 tx: mpsc::Sender<RelayPacket>,
76 close_tx: mpsc::Sender<()>,
77) {
78 while let Some(message) = rx.next().await {
79 match message {
80 Ok(message) => {
81 if let Message::Binary(msg) = message {
82 let buf: Bytes = msg.into();
83 match RelayPacket::decode_proto(buf).await {
84 Ok(result) => {
85 if let Err(e) = tx.send(result).await {
86 tracing::error!(error = ?e);
87 }
88 }
89 Err(e) => {
90 tracing::error!(error = ?e);
91 let _ = close_tx.send(()).await;
92 break;
93 }
94 }
95 }
96 }
97 Err(e) => {
98 tracing::error!(error = ?e);
99 let _ = close_tx.send(()).await;
100 break;
101 }
102 }
103 }
104 tracing::debug!("pairing::websocket::connection_closed");
105}
106
/// Offering side of device pairing for an existing [`NetworkAccount`].
///
/// On receiving a pairing request over the encrypted relay tunnel it
/// registers the requesting device and replies with the data that device
/// needs to enroll.
pub struct OfferPairing<'a> {
    /// Noise keypair generated for this pairing session; its public key is
    /// sent to the relay as the `public_key` query parameter.
    keypair: Keypair,
    /// Account being shared with the new device.
    account: &'a mut NetworkAccount,
    /// Pairing share URL: created by this side in the default mode, or
    /// received from the accepting device in inverted mode.
    share_url: ServerPairUrl,
    /// Noise state; an `Option` so it can be taken and replaced when moving
    /// from handshake to transport mode.
    tunnel: Option<Tunnel>,
    /// Write half of the relay websocket.
    tx: WsSink,
    /// Current protocol progress.
    state: PairProtocolState,
    /// `true` when roles are swapped and this side is the Noise initiator.
    is_inverted: bool,
}
125
126impl<'a> OfferPairing<'a> {
127 pub async fn new(
129 account: &'a mut NetworkAccount,
130 url: Url,
131 ) -> Result<(OfferPairing<'a>, WsStream)> {
132 let builder = Builder::new(PATTERN.parse()?);
133 let keypair = builder.generate_keypair()?;
134 let share_url = ServerPairUrl::new(
135 *account.account_id(),
136 url.clone(),
137 keypair.public.clone(),
138 );
139 Self::new_connection(account, share_url, keypair, false).await
140 }
141
142 pub async fn new_inverted(
145 account: &'a mut NetworkAccount,
146 share_url: ServerPairUrl,
147 ) -> Result<(OfferPairing<'a>, WsStream)> {
148 let builder = Builder::new(PATTERN.parse()?);
149 let keypair = builder.generate_keypair()?;
150 Self::new_connection(account, share_url, keypair, true).await
151 }
152
153 async fn new_connection(
154 account: &'a mut NetworkAccount,
155 share_url: ServerPairUrl,
156 keypair: Keypair,
157 is_inverted: bool,
158 ) -> Result<(OfferPairing<'a>, WsStream)> {
159 let psk = share_url.pre_shared_key().to_vec();
160 let tunnel = if is_inverted {
161 Builder::new(PATTERN.parse()?)
162 .local_private_key(&keypair.private)
163 .remote_public_key(share_url.public_key())
164 .psk(3, &psk)
165 .build_initiator()?
166 } else {
167 Builder::new(PATTERN.parse()?)
168 .local_private_key(&keypair.private)
169 .psk(3, &psk)
170 .build_responder()?
171 };
172
173 let mut request = WebSocketRequest::new(
174 *account.account_id(),
175 share_url.server(),
176 RELAY_PATH,
177 )?;
178 request
179 .uri
180 .query_pairs_mut()
181 .append_pair("public_key", &hex::encode(&keypair.public));
182
183 let (socket, _) = connect_async(request).await?;
184 let (tx, rx) = socket.split();
185 Ok((
186 Self {
187 keypair,
188 account,
189 share_url,
190 tunnel: Some(Tunnel::Handshake(tunnel)),
191 tx,
192 state: PairProtocolState::Pending,
193 is_inverted,
194 },
195 rx,
196 ))
197 }
198
199 pub fn share_url(&self) -> &ServerPairUrl {
201 &self.share_url
202 }
203
204 pub async fn run(
206 &mut self,
207 stream: WsStream,
208 mut shutdown_rx: mpsc::Receiver<()>,
209 ) -> Result<()> {
210 if self.is_inverted {
211 self.noise_send_e().await?;
213 self.state = PairProtocolState::Handshake;
214 }
215
216 let (offer_tx, mut offer_rx) = mpsc::channel::<RelayPacket>(32);
217 let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
218 tokio::task::spawn(listen(stream, offer_tx, close_tx));
219 loop {
220 tokio::select! {
221 biased;
222 Some(_) = shutdown_rx.recv() => {
224 tracing::debug!("pairing::offer::shutdown_received");
225 if let Err(error) = self.tx.send(Message::Close(Some(CloseFrame {
226 code: CloseCode::Normal,
227 reason: Utf8Bytes::from_static("closed"),
228 }))).await {
229 tracing::error!(
230 error = %error,
231 "pairing::offer::websocket_close_frame::error");
232 }
233 break;
234 }
235 Some(_) = close_rx.recv() => {
237 break;
238 }
239 Some(event) = offer_rx.recv() => {
241 self.incoming(event).await?;
242 if self.is_finished() {
243 break;
244 }
245 }
246 }
247 }
248
249 Ok(())
250 }
251
252 pub fn is_finished(&self) -> bool {
254 matches!(&self.state, PairProtocolState::Done)
255 }
256
257 async fn incoming(&mut self, packet: RelayPacket) -> Result<()> {
259 if packet.header.as_ref().unwrap().to_public_key
260 != self.keypair.public
261 {
262 return Err(Error::NotForMe);
263 }
264
265 let action = if !self.is_inverted {
266 match (&self.state, packet.is_handshake()) {
267 (PairProtocolState::Pending, true) => {
268 let reply = self.noise_read_e(&packet).await?;
269 IncomingAction::Reply(PairProtocolState::Handshake, reply)
270 }
271 (PairProtocolState::Handshake, true) => {
272 let reply = self.noise_read_s(&packet).await?;
273 IncomingAction::Reply(
274 PairProtocolState::PskHandshake,
275 reply,
276 )
277 }
278 (PairProtocolState::PskHandshake, false) => {
279 if let Some(Tunnel::Transport(transport)) =
280 self.tunnel.as_mut()
281 {
282 let payload = packet.payload.as_ref().unwrap();
283 let body = payload.body.as_ref().unwrap();
284 let (len, buf) =
285 (body.length as usize, &body.contents);
286
287 IncomingAction::HandleMessage(
288 decrypt(transport, len, buf).await?,
289 )
290 } else {
291 unreachable!();
292 }
293 }
294 _ => {
295 return Err(Error::BadState);
296 }
297 }
298 } else {
299 match (&self.state, packet.is_handshake()) {
300 (PairProtocolState::Handshake, true) => {
301 let reply = self.noise_send_s(&packet).await?;
302 IncomingAction::Reply(
303 PairProtocolState::PskHandshake,
304 reply,
305 )
306 }
307 (PairProtocolState::PskHandshake, false) => {
308 if let Some(Tunnel::Transport(transport)) =
309 self.tunnel.as_mut()
310 {
311 let payload = packet.payload.as_ref().unwrap();
312 let body = payload.body.as_ref().unwrap();
313 let (len, buf) =
314 (body.length as usize, &body.contents);
315
316 IncomingAction::HandleMessage(
317 decrypt(transport, len, buf).await?,
318 )
319 } else {
320 unreachable!();
321 }
322 }
323 _ => {
324 return Err(Error::BadState);
325 }
326 }
327 };
328
329 match action {
330 IncomingAction::Reply(next_state, reply) => {
331 self.state = next_state;
332 let buffer = reply.encode_prefixed().await?;
333 self.tx.send(Message::Binary(buffer.into())).await?;
334 }
335 IncomingAction::HandleMessage(msg) => {
336 let msg = msg.inner.unwrap();
337 if let pairing_message::Inner::Ready(_) = msg {
341 let payload = if let Some(Tunnel::Transport(transport)) =
342 self.tunnel_mut()
343 {
344 let private_message = PairingReady {};
345 encrypt(
346 transport,
347 PairingMessage {
348 inner: Some(pairing_message::Inner::Ready(
349 private_message,
350 )),
351 },
352 )
353 .await?
354 } else {
355 unreachable!();
356 };
357 let reply = RelayPacket {
358 header: Some(RelayHeader {
359 to_public_key: packet
360 .header
361 .as_ref()
362 .unwrap()
363 .from_public_key
364 .clone(),
365 from_public_key: self.keypair().public.clone(),
366 }),
367 payload: Some(payload),
368 };
369
370 let buffer = reply.encode_prefixed().await?;
371 self.tx.send(Message::Binary(buffer.into())).await?;
372 } else if let pairing_message::Inner::Request(message) = msg {
373 tracing::debug!("<- device");
374
375 let device_bytes = message.device_meta_data;
376 let device: DeviceMetaData =
377 serde_json::from_slice(&device_bytes)?;
378
379 let (device_signer, manager) =
380 self.account.new_device_vault().await?;
381 let device_vault = manager.into_vault_buffer().await?;
382 let servers = self.account.servers().await;
383 let account_name = self.account.account_name().await?;
384
385 self.register_device(device_signer.public_key(), device)
386 .await?;
387
388 let private_message = PairingConfirm {
389 account_id: message.account_id,
390 account_name,
391 device_signing_key: device_signer.to_bytes().to_vec(),
392 device_vault,
393 servers: servers
394 .into_iter()
395 .map(|s| s.into())
396 .collect(),
397 };
398
399 let payload = if let Some(Tunnel::Transport(transport)) =
400 self.tunnel.as_mut()
401 {
402 encrypt(
403 transport,
404 PairingMessage {
405 inner: Some(pairing_message::Inner::Confirm(
406 private_message,
407 )),
408 },
409 )
410 .await?
411 } else {
412 unreachable!();
413 };
414
415 let reply = RelayPacket {
416 header: Some(RelayHeader {
417 to_public_key: packet
418 .header
419 .unwrap()
420 .from_public_key
421 .to_vec(),
422 from_public_key: self.keypair.public.to_vec(),
423 }),
424 payload: Some(payload),
425 };
426
427 tracing::debug!("-> private-key");
428 let buffer = reply.encode_prefixed().await?;
429 self.tx.send(Message::Binary(buffer.into())).await?;
430 self.state = PairProtocolState::Done;
431 } else {
432 return Err(Error::BadState);
433 }
434 }
435 }
436
437 Ok(())
438 }
439
440 async fn register_device(
441 &mut self,
442 public_key: DevicePublicKey,
443 device: DeviceMetaData,
444 ) -> Result<()> {
445 let trusted_device =
446 TrustedDevice::new(public_key, Some(device), None);
447 let events: Vec<DeviceEvent> =
449 vec![DeviceEvent::Trust(trusted_device)];
450 {
451 self.account
452 .patch_devices_unchecked(events.as_slice())
453 .await?;
454 }
455
456 let origins = vec![self.share_url.server().clone().into()];
465 let options = SyncOptions {
466 origins,
467 ..Default::default()
468 };
469 if let Some(sync_error) =
470 self.account.sync_with_options(&options).await.first_error()
471 {
472 return Err(Error::DevicePatchSync(Box::new(sync_error)));
473 }
474
475 if let Some(sync_error) =
481 self.account.sync_with_options(&options).await.first_error()
482 {
483 return Err(Error::EnrollSync(Box::new(sync_error)));
484 }
485
486 Ok(())
487 }
488}
489
490impl<'a> NoiseTunnel for OfferPairing<'a> {
491 async fn send(&mut self, message: Message) -> Result<()> {
492 Ok(self.tx.send(message).await?)
493 }
494
495 fn pairing_public_key(&self) -> &[u8] {
496 self.share_url.public_key()
497 }
498
499 fn keypair(&self) -> &Keypair {
500 &self.keypair
501 }
502
503 fn tunnel_mut(&mut self) -> Option<&mut Tunnel> {
504 self.tunnel.as_mut()
505 }
506
507 fn into_transport_mode(&mut self) -> Result<()> {
508 let tunnel = self.tunnel.take().unwrap();
509 if let Tunnel::Handshake(state) = tunnel {
510 self.tunnel =
511 Some(Tunnel::Transport(state.into_transport_mode()?));
512 }
513 Ok(())
514 }
515}
516
/// Accepting side of device pairing: the new device that requests access to
/// an existing account and builds a [`DeviceEnrollment`] from the reply.
pub struct AcceptPairing<'a> {
    /// Noise keypair generated for this pairing session; its public key is
    /// sent to the relay as the `public_key` query parameter.
    keypair: Keypair,
    /// Metadata of this device, sent to the offering side in the pairing
    /// request.
    device: &'a DeviceMetaData,
    /// Backend target used when creating the device enrollment.
    target: BackendTarget,
    /// Pairing share URL: received from the offering device in the default
    /// mode, or generated by this side in inverted mode.
    share_url: ServerPairUrl,
    /// Noise state; an `Option` so it can be taken and replaced when moving
    /// from handshake to transport mode.
    tunnel: Option<Tunnel>,
    /// Write half of the relay websocket.
    tx: WsSink,
    /// Current protocol progress.
    state: PairProtocolState,
    /// Enrollment created once the offering side's confirmation arrives.
    enrollment: Option<DeviceEnrollment>,
    /// `true` when roles are swapped and this side is the Noise responder.
    is_inverted: bool,
}
538
539impl<'a> AcceptPairing<'a> {
540 pub async fn new(
542 share_url: ServerPairUrl,
543 device: &'a DeviceMetaData,
544 target: BackendTarget,
545 ) -> Result<(AcceptPairing<'a>, WsStream)> {
546 let builder = Builder::new(PATTERN.parse()?);
547 let keypair = builder.generate_keypair()?;
548 Self::new_connection(share_url, device, target, keypair, false).await
549 }
550
551 pub async fn new_inverted(
553 account_id: AccountId,
554 server: Url,
555 device: &'a DeviceMetaData,
556 target: BackendTarget,
557 ) -> Result<(ServerPairUrl, AcceptPairing<'a>, WsStream)> {
558 let builder = Builder::new(PATTERN.parse()?);
559 let keypair = builder.generate_keypair()?;
560 let share_url =
561 ServerPairUrl::new(account_id, server, keypair.public.clone());
562 let (pairing, stream) = Self::new_connection(
563 share_url.clone(),
564 device,
565 target,
566 keypair,
567 true,
568 )
569 .await?;
570 Ok((share_url, pairing, stream))
571 }
572
573 async fn new_connection(
574 share_url: ServerPairUrl,
575 device: &'a DeviceMetaData,
576 target: BackendTarget,
577 keypair: Keypair,
578 is_inverted: bool,
579 ) -> Result<(AcceptPairing<'a>, WsStream)> {
580 let psk = share_url.pre_shared_key().to_vec();
581 let tunnel = if is_inverted {
582 Builder::new(PATTERN.parse()?)
583 .local_private_key(&keypair.private)
584 .psk(3, &psk)
585 .build_responder()?
586 } else {
587 Builder::new(PATTERN.parse()?)
588 .local_private_key(&keypair.private)
589 .remote_public_key(share_url.public_key())
590 .psk(3, &psk)
591 .build_initiator()?
592 };
593
594 let mut request = WebSocketRequest::new(
595 *share_url.account_id(),
596 share_url.server(),
597 RELAY_PATH,
598 )?;
599 request
600 .uri
601 .query_pairs_mut()
602 .append_pair("public_key", &hex::encode(&keypair.public));
603 let (socket, _) = connect_async(request).await?;
604 let (tx, rx) = socket.split();
605 Ok((
606 Self {
607 keypair,
608 device,
609 share_url,
610 target,
611 tunnel: Some(Tunnel::Handshake(tunnel)),
612 tx,
613 state: PairProtocolState::Pending,
614 enrollment: None,
615 is_inverted,
616 },
617 rx,
618 ))
619 }
620
621 pub async fn run(
623 &mut self,
624 stream: WsStream,
625 mut shutdown_rx: mpsc::Receiver<()>,
626 ) -> Result<()> {
627 if !self.is_inverted {
628 self.noise_send_e().await?;
630 self.state = PairProtocolState::Handshake;
631 }
632
633 let (offer_tx, mut offer_rx) = mpsc::channel::<RelayPacket>(32);
635 let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
636 tokio::task::spawn(listen(stream, offer_tx, close_tx));
637
638 loop {
639 tokio::select! {
640 biased;
641 event = shutdown_rx.recv() => {
642 if event.is_some() {
643 let _ = self.tx.send(Message::Close(Some(CloseFrame {
644 code: CloseCode::Normal,
645 reason: Utf8Bytes::from_static("closed"),
646 }))).await;
647 break;
648 }
649 }
650 event = offer_rx.recv() => {
651 if let Some(event) = event {
652 self.incoming(event).await?;
653 if self.is_finished() {
654 break;
655 }
656 }
657 }
658 event = close_rx.recv() => {
659 if event.is_some() {
660 break;
661 }
662 }
663 }
664 }
665
666 Ok(())
667 }
668
669 pub fn is_finished(&self) -> bool {
671 matches!(&self.state, PairProtocolState::Done)
672 }
673
674 pub fn take_enrollment(self) -> Result<DeviceEnrollment> {
678 self.enrollment.ok_or(Error::NoEnrollment)
679 }
680
681 async fn incoming(&mut self, packet: RelayPacket) -> Result<()> {
683 if packet.header.as_ref().unwrap().to_public_key
684 != self.keypair.public
685 {
686 return Err(Error::NotForMe);
687 }
688
689 let action = if !self.is_inverted {
690 match (&self.state, packet.is_handshake()) {
691 (PairProtocolState::Handshake, true) => {
692 let reply = self.noise_send_s(&packet).await?;
693 IncomingAction::Reply(
694 PairProtocolState::PskHandshake,
695 reply,
696 )
697 }
698 (PairProtocolState::PskHandshake, false) => {
699 if let Some(Tunnel::Transport(transport)) =
700 self.tunnel.as_mut()
701 {
702 let payload = packet.payload.as_ref().unwrap();
703 let body = payload.body.as_ref().unwrap();
704 let (len, buf) =
705 (body.length as usize, &body.contents);
706
707 IncomingAction::HandleMessage(
708 decrypt(transport, len, buf).await?,
709 )
710 } else {
711 unreachable!();
712 }
713 }
714 _ => {
715 return Err(Error::BadState);
716 }
717 }
718 } else {
719 match (&self.state, packet.is_handshake()) {
720 (PairProtocolState::Pending, true) => {
721 let reply = self.noise_read_e(&packet).await?;
722 IncomingAction::Reply(PairProtocolState::Handshake, reply)
723 }
724 (PairProtocolState::Handshake, true) => {
725 let reply = self.noise_read_s(&packet).await?;
726 IncomingAction::Reply(
727 PairProtocolState::PskHandshake,
728 reply,
729 )
730 }
731 (
732 PairProtocolState::PskHandshake,
733 false,
734 ) => {
736 if let Some(Tunnel::Transport(transport)) =
737 self.tunnel.as_mut()
738 {
739 let payload = packet.payload.as_ref().unwrap();
740 let body = payload.body.as_ref().unwrap();
741 let (len, buf) =
742 (body.length as usize, &body.contents);
743
744 IncomingAction::HandleMessage(
745 decrypt(transport, len, buf).await?,
746 )
747 } else {
748 unreachable!();
749 }
750 }
751 _ => {
752 return Err(Error::BadState);
753 }
754 }
755 };
756
757 match action {
758 IncomingAction::Reply(next_state, reply) => {
759 self.state = next_state;
760
761 let buffer = reply.encode_prefixed().await?;
762 self.tx.send(Message::Binary(buffer.into())).await?;
763 }
764 IncomingAction::HandleMessage(msg) => {
765 let msg = msg.inner.unwrap();
766
767 if let pairing_message::Inner::Ready(_) = msg {
770 tracing::debug!("<- ready");
771 if let Some(Tunnel::Transport(transport)) =
772 self.tunnel.as_mut()
773 {
774 let device_bytes = serde_json::to_vec(&self.device)?;
775
776 let private_message = PairingRequest {
777 device_meta_data: device_bytes,
778 account_id: self
779 .share_url
780 .account_id()
781 .to_string(),
782 };
783
784 let payload = encrypt(
785 transport,
786 PairingMessage {
787 inner: Some(pairing_message::Inner::Request(
788 private_message,
789 )),
790 },
791 )
792 .await?;
793 let reply = RelayPacket {
794 header: Some(RelayHeader {
795 to_public_key: packet
796 .header
797 .as_ref()
798 .unwrap()
799 .from_public_key
800 .to_vec(),
801 from_public_key: self.keypair.public.to_vec(),
802 }),
803 payload: Some(payload),
804 };
805 tracing::debug!("-> device");
806 let buffer = reply.encode_prefixed().await?;
807 self.tx.send(Message::Binary(buffer.into())).await?;
808 } else {
809 unreachable!();
810 }
811 } else if let pairing_message::Inner::Confirm(confirmation) =
812 msg
813 {
814 self.create_enrollment(confirmation).await?;
815 self.state = PairProtocolState::Done;
816 } else {
817 return Err(Error::BadState);
818 }
819 }
820 }
821
822 Ok(())
823 }
824
825 async fn create_enrollment(
832 &mut self,
833 confirmation: PairingConfirm,
834 ) -> Result<()> {
835 let device_signing_key: [u8; 32] =
839 confirmation.device_signing_key.as_slice().try_into()?;
840 let device_vault = confirmation.device_vault;
841 let mut servers = HashSet::new();
842 for server in confirmation.servers {
843 servers.insert(server.try_into()?);
844 }
845 let account_id: AccountId = confirmation.account_id.parse()?;
846
847 let server = self.share_url.server().clone();
850 let origin: Origin = server.into();
851 let enrollment = DeviceEnrollment::new(
854 self.target.clone(),
855 account_id,
856 confirmation.account_name,
857 origin,
858 device_signing_key.try_into()?,
859 device_vault,
860 servers,
861 )
862 .await?;
863 self.enrollment = Some(enrollment);
864
865 Ok(())
866 }
867}
868
869async fn encrypt<T: prost::Message>(
871 transport: &mut TransportState,
872 message: T,
873) -> crate::pairing::Result<RelayPayload> {
874 let mut plaintext = Vec::new();
875 message.encode(&mut plaintext)?;
876 let mut contents = vec![0u8; plaintext.len() + TAGLEN];
877 let length = transport.write_message(&plaintext, &mut contents)?;
878 Ok(RelayPayload::new_transport(length, contents))
879}
880
881async fn decrypt<T: prost::Message + Default>(
883 transport: &mut TransportState,
884 length: usize,
885 message: &[u8],
886) -> crate::pairing::Result<T> {
887 let mut contents = vec![0; length];
888 transport.read_message(&message[..length], &mut contents)?;
889 let message = &contents[..contents.len() - TAGLEN];
890 let message: prost::bytes::Bytes = message.to_vec().into();
891 Ok(T::decode(message)?)
892}
893
894impl<'a> NoiseTunnel for AcceptPairing<'a> {
895 async fn send(&mut self, message: Message) -> Result<()> {
896 Ok(self.tx.send(message).await?)
897 }
898
899 fn pairing_public_key(&self) -> &[u8] {
900 self.share_url.public_key()
901 }
902
903 fn keypair(&self) -> &Keypair {
904 &self.keypair
905 }
906
907 fn tunnel_mut(&mut self) -> Option<&mut Tunnel> {
908 self.tunnel.as_mut()
909 }
910
911 fn into_transport_mode(&mut self) -> Result<()> {
912 let tunnel = self.tunnel.take().unwrap();
913 if let Tunnel::Handshake(state) = tunnel {
914 self.tunnel =
915 Some(Tunnel::Transport(state.into_transport_mode()?));
916 }
917 Ok(())
918 }
919}
920
921trait NoiseTunnel {
922 async fn send(&mut self, message: Message) -> Result<()>;
924
925 fn pairing_public_key(&self) -> &[u8];
927
928 fn keypair(&self) -> &Keypair;
930
931 fn tunnel_mut(&mut self) -> Option<&mut Tunnel>;
933
934 fn into_transport_mode(&mut self) -> Result<()>;
936
937 async fn noise_send_e(&mut self) -> Result<()> {
939 let buffer = if let Some(Tunnel::Handshake(state)) = self.tunnel_mut()
940 {
941 let mut buf = [0u8; 1024];
942 tracing::debug!("-> e");
944 let len = state.write_message(&[], &mut buf)?;
945 let message = RelayPacket {
946 header: Some(RelayHeader {
947 to_public_key: self.pairing_public_key().to_vec(),
948 from_public_key: self.keypair().public.to_vec(),
949 }),
950 payload: Some(RelayPayload::new_handshake(len, buf.to_vec())),
951 };
952 message.encode_prefixed().await?
953 } else {
954 unreachable!();
955 };
956 self.send(Message::Binary(buffer.into())).await?;
957 Ok(())
958 }
959
960 async fn noise_read_e(
962 &mut self,
963 packet: &RelayPacket,
964 ) -> Result<RelayPacket> {
965 if let (Some(Tunnel::Handshake(state)), true) =
966 (self.tunnel_mut(), packet.is_handshake())
967 {
968 let payload = packet.payload.as_ref().unwrap();
969 let body = payload.body.as_ref().unwrap();
970 let (len, init_msg) = (body.length as usize, &body.contents);
971
972 let mut buf = [0; 1024];
973 let mut reply = [0; 1024];
974 tracing::debug!("<- e");
976 state.read_message(&init_msg[..len], &mut buf)?;
977 tracing::debug!("-> e, ee, s, es");
979 let len = state.write_message(&[], &mut reply)?;
980 Ok(RelayPacket {
981 header: Some(RelayHeader {
982 to_public_key: packet
983 .header
984 .as_ref()
985 .unwrap()
986 .from_public_key
987 .clone(),
988 from_public_key: self.keypair().public.clone(),
989 }),
990 payload: Some(RelayPayload::new_handshake(
991 len,
992 reply.to_vec(),
993 )),
994 })
995 } else {
996 Err(Error::BadState)
997 }
998 }
999
1000 async fn noise_send_s(
1003 &mut self,
1004 packet: &RelayPacket,
1005 ) -> Result<RelayPacket> {
1006 let packet = if let (Some(Tunnel::Handshake(state)), true) =
1007 (self.tunnel_mut(), packet.is_handshake())
1008 {
1009 let payload = packet.payload.as_ref().unwrap();
1010 let body = payload.body.as_ref().unwrap();
1011 let (len, init_msg) = (body.length as usize, &body.contents);
1012
1013 let mut buf = [0; 1024];
1014 let mut reply = [0; 1024];
1015 tracing::debug!("<- e, ee, s, es");
1017 state.read_message(&init_msg[..len], &mut buf)?;
1018 tracing::debug!("-> s, se");
1020 let len = state.write_message(&[], &mut reply)?;
1021 Some(RelayPacket {
1022 header: Some(RelayHeader {
1023 to_public_key: packet
1024 .header
1025 .as_ref()
1026 .unwrap()
1027 .from_public_key
1028 .clone(),
1029 from_public_key: self.keypair().public.clone(),
1030 }),
1031 payload: Some(RelayPayload::new_handshake(
1032 len,
1033 reply.to_vec(),
1034 )),
1035 })
1036 } else {
1037 None
1038 };
1039
1040 if let Some(packet) = packet {
1041 self.into_transport_mode()?;
1042 Ok(packet)
1043 } else {
1044 return Err(Error::BadState);
1045 }
1046 }
1047
1048 async fn noise_read_s(
1051 &mut self,
1052 packet: &RelayPacket,
1053 ) -> Result<RelayPacket> {
1054 if let (Some(Tunnel::Handshake(state)), true) =
1055 (self.tunnel_mut(), packet.is_handshake())
1056 {
1057 let payload = packet.payload.as_ref().unwrap();
1058 let body = payload.body.as_ref().unwrap();
1059 let (len, init_msg) = (body.length as usize, &body.contents);
1060
1061 let mut buf = [0; 1024];
1062 tracing::debug!("<- s, se");
1064 state.read_message(&init_msg[..len], &mut buf)?;
1065
1066 self.into_transport_mode()?;
1067
1068 let payload = if let Some(Tunnel::Transport(transport)) =
1069 self.tunnel_mut()
1070 {
1071 let private_message = PairingReady {};
1072 encrypt(
1073 transport,
1074 PairingMessage {
1075 inner: Some(pairing_message::Inner::Ready(
1076 private_message,
1077 )),
1078 },
1079 )
1080 .await?
1081 } else {
1082 unreachable!();
1083 };
1084 Ok(RelayPacket {
1085 header: Some(RelayHeader {
1086 to_public_key: packet
1087 .header
1088 .as_ref()
1089 .unwrap()
1090 .from_public_key
1091 .clone(),
1092 from_public_key: self.keypair().public.clone(),
1093 }),
1094 payload: Some(payload),
1095 })
1096 } else {
1097 Err(Error::BadState)
1098 }
1099 }
1100}