1#![recursion_limit = "256"]
14#![doc(
15 html_logo_url = "https://github.com/maidsafe/QA/raw/master/Images/maidsafe_logo.png",
16 html_favicon_url = "https://maidsafe.net/img/favicon.ico",
17 test(attr(deny(warnings)))
18)]
19#![forbid(
22 arithmetic_overflow,
23 mutable_transmutes,
24 no_mangle_const_items,
25 unknown_crate_types,
26 unsafe_code
27)]
28#![warn(
30 missing_debug_implementations,
31 missing_docs,
32 trivial_casts,
33 trivial_numeric_casts,
34 unreachable_pub,
35 unused_extern_crates,
36 unused_import_braces,
37 unused_qualifications,
38 unused_results,
39 clippy::unicode_not_nfc,
40 clippy::unwrap_used,
41 clippy::unused_async
42)]
43
44#[macro_use]
45extern crate tracing;
46
47mod error;
48mod listener;
49mod node_link;
50
51pub use self::error::{Error, Result};
52
53use self::node_link::NodeLink;
54
55use sn_interface::{
56 messaging::{
57 data::{DataResponse, Error as MsgError},
58 Dst, MsgId, MsgKind, WireMsg,
59 },
60 types::{NodeId, Participant},
61};
62
63use futures::future::join_all;
64use qp2p::{Endpoint, SendStream, UsrMsgBytes};
65use std::{
66 collections::{BTreeMap, BTreeSet},
67 net::SocketAddr,
68};
69use tokio::{
70 sync::mpsc::{self, Receiver, Sender},
71 task,
72};
73
/// Capacity of the bounded channel that carries `CommCmd`s to the command-processing task.
const STANDARD_CHANNEL_SIZE: usize = 100_000;
76
77#[derive(Debug)]
79pub enum CommEvent {
80 Msg(MsgReceived),
82 Error {
84 node_id: NodeId,
86 error: Error,
88 },
89}
90
91#[derive(Debug)]
93pub struct MsgReceived {
94 pub sender: Participant,
96 pub wire_msg: WireMsg,
98 pub send_stream: Option<SendStream>,
101}
102
103#[derive(Clone, Debug)]
108pub struct Comm {
109 our_endpoint: Endpoint,
110 public_addr: Option<SocketAddr>,
111 cmd_sender: Sender<CommCmd>,
112}
113
114impl Comm {
115 #[tracing::instrument(skip_all)]
118 pub fn new(
119 local_addr: SocketAddr,
120 mut public_addr: Option<SocketAddr>,
121 ) -> Result<(Self, Receiver<CommEvent>)> {
122 let (our_endpoint, incoming_conns) = Endpoint::builder()
123 .addr(local_addr)
124 .idle_timeout(70_000)
125 .server()?;
126
127 if let Some(ref mut addr) = public_addr {
129 if addr.port() == 0 {
130 addr.set_port(our_endpoint.local_addr().port());
131 }
132 }
133
134 trace!("Creating comms..");
135 let (comm_events_sender, comm_events_receiver) = mpsc::channel(1);
140 let (cmd_sender, cmd_receiver) = mpsc::channel(STANDARD_CHANNEL_SIZE);
141
142 listener::listen_for_connections(comm_events_sender.clone(), incoming_conns);
144
145 process_cmds(our_endpoint.clone(), cmd_receiver, comm_events_sender);
146
147 Ok((
148 Self {
149 our_endpoint,
150 public_addr,
151 cmd_sender,
152 },
153 comm_events_receiver,
154 ))
155 }
156
157 pub fn socket_addr(&self) -> SocketAddr {
159 match self.public_addr {
160 Some(addr) => addr,
161 None => self.our_endpoint.local_addr(),
162 }
163 }
164
165 pub fn close_endpoint(&self) {
167 self.our_endpoint.close()
168 }
169
170 pub fn set_comm_targets(&self, targets: BTreeSet<NodeId>) {
172 self.send_cmd(CommCmd::SetTargets(targets))
176 }
177
178 #[tracing::instrument(skip(self, bytes))]
180 pub fn send_out_bytes(&self, node_id: NodeId, msg_id: MsgId, bytes: UsrMsgBytes) {
181 self.send_cmd(CommCmd::Send {
182 msg_id,
183 node_id,
184 bytes,
185 })
186 }
187
188 #[tracing::instrument(skip(self, bytes))]
190 pub fn send_and_return_response(&self, node_id: NodeId, msg_id: MsgId, bytes: UsrMsgBytes) {
191 self.send_cmd(CommCmd::SendAndReturnResponse {
192 msg_id,
193 node_id,
194 bytes,
195 })
196 }
197
198 #[tracing::instrument(skip(self, node_bytes))]
200 pub fn send_and_respond_on_stream(
201 &self,
202 msg_id: MsgId,
203 node_bytes: BTreeMap<NodeId, UsrMsgBytes>,
204 expected_targets: usize,
205 dst_stream: (Dst, SendStream),
206 ) {
207 self.send_cmd(CommCmd::SendAndRespondOnStream {
208 msg_id,
209 node_bytes,
210 expected_targets,
211 dst_stream,
212 })
213 }
214
215 fn send_cmd(&self, cmd: CommCmd) {
216 let sender = self.cmd_sender.clone();
217 let _handle = task::spawn(async move {
218 let error_msg = format!("Failed to send {cmd:?} on comm cmd channel ");
219 if let Err(error) = sender.send(cmd).await {
220 error!("{error_msg} due to {error}.");
221 }
222 });
223 }
224}
225
226#[derive(custom_debug::Debug)]
228enum CommCmd {
229 Send {
230 msg_id: MsgId,
231 node_id: NodeId,
232 #[debug(skip)]
233 bytes: UsrMsgBytes,
234 },
235 SetTargets(BTreeSet<NodeId>),
236 SendAndReturnResponse {
237 node_id: NodeId,
238 msg_id: MsgId,
239 #[debug(skip)]
240 bytes: UsrMsgBytes,
241 },
242 SendAndRespondOnStream {
243 msg_id: MsgId,
244 #[debug(skip)]
245 node_bytes: BTreeMap<NodeId, UsrMsgBytes>,
246 expected_targets: usize,
247 dst_stream: (Dst, SendStream),
248 },
249}
250
251fn process_cmds(
252 our_endpoint: Endpoint,
253 mut cmd_receiver: Receiver<CommCmd>,
254 comm_events: Sender<CommEvent>,
255) {
256 let _handle = task::spawn(async move {
257 let mut links = BTreeMap::<NodeId, NodeLink>::new();
258 while let Some(cmd) = cmd_receiver.recv().await {
259 trace!("Comms cmd handling: {cmd:?}");
260 match cmd {
261 CommCmd::SetTargets(targets) => {
263 links.retain(|p, _| targets.contains(p));
265 targets.iter().for_each(|node_id| {
267 if links.get(node_id).is_none() {
268 let link = NodeLink::new(*node_id, our_endpoint.clone());
269 let _ = links.insert(*node_id, link);
270 }
271 });
272
273 debug!("CommLinks stored #{:?}", links.len());
274 }
275 CommCmd::Send {
276 msg_id,
277 node_id,
278 bytes,
279 } => {
280 if let Some(link) = get_link(msg_id, node_id, &links, comm_events.clone()) {
281 send(msg_id, link, bytes, comm_events.clone())
282 }
283 }
284 CommCmd::SendAndReturnResponse {
285 node_id,
286 msg_id,
287 bytes,
288 } => {
289 if let Some(link) = get_link(msg_id, node_id, &links, comm_events.clone()) {
290 send_and_return_response(msg_id, link, bytes, comm_events.clone())
291 }
292 }
293 CommCmd::SendAndRespondOnStream {
294 msg_id,
295 node_bytes,
296 expected_targets,
297 dst_stream,
298 } => {
299 let node_bytes = node_bytes
300 .into_iter()
301 .map(|(node_id, bytes)| {
302 let link = get_link(msg_id, node_id, &links, comm_events.clone());
303 (node_id, (link, bytes))
304 })
305 .collect();
306
307 send_and_respond_on_stream(
308 msg_id,
309 node_bytes,
310 expected_targets,
311 dst_stream,
312 comm_events.clone(),
313 )
314 }
315 }
316 }
317 });
318}
319
320fn get_link(
321 msg_id: MsgId,
322 node_id: NodeId,
323 links: &BTreeMap<NodeId, NodeLink>,
324 comm_events: Sender<CommEvent>,
325) -> Option<NodeLink> {
326 debug!("Trying to get {node_id:?} link in order to send: {msg_id:?}");
327 match links.get(&node_id) {
328 Some(link) => Some(link.clone()),
329 None => {
330 error!("Sending message (msg_id: {msg_id:?}) to {node_id:?} failed: unknown node.");
331 send_error(node_id, Error::ConnectingToUnknownNode(msg_id), comm_events);
332 None
333 }
334 }
335}
336
337#[tracing::instrument(skip_all)]
338fn send(msg_id: MsgId, mut link: NodeLink, bytes: UsrMsgBytes, comm_events: Sender<CommEvent>) {
339 let _handle = task::spawn(async move {
340 let (h, d, p) = &bytes;
341 let bytes_len = h.len() + d.len() + p.len();
342 let node_id = link.node_id();
343 trace!("Sending message bytes ({bytes_len} bytes) w/ {msg_id:?} to {node_id:?}");
344 match link.send(msg_id, bytes).await {
345 Ok(()) => {
346 trace!("Msg {msg_id:?} sent to {node_id:?}");
347 }
348 Err(error) => {
349 error!("Sending message (msg_id: {msg_id:?}) to {node_id:?} failed: {error}");
350 send_error(node_id, Error::FailedSend(msg_id), comm_events.clone());
351 }
352 }
353 });
354}
355
356#[tracing::instrument(skip_all)]
357fn send_and_return_response(
358 msg_id: MsgId,
359 link: NodeLink,
360 bytes: UsrMsgBytes,
361 comm_events: Sender<CommEvent>,
362) {
363 let _handle = task::spawn(async move {
364 let (h, d, p) = &bytes;
365 let bytes_len = h.len() + d.len() + p.len();
366 let node_id = link.node_id();
367 trace!("Sending message bytes ({bytes_len} bytes) w/ {msg_id:?} to {node_id:?}");
368
369 let node_response_bytes = match link.send_with_bi_return_response(bytes, msg_id).await {
370 Ok(response_bytes) => {
371 debug!("Node response from {node_id:?} is in for {msg_id:?}");
372 response_bytes
373 }
374 Err(error) => {
375 error!("Sending message (msg_id: {msg_id:?}) to {node_id:?} failed: {error}");
376 send_error(node_id, Error::FailedSend(msg_id), comm_events.clone());
377 return;
378 }
379 };
380 match WireMsg::from(node_response_bytes) {
381 Ok(wire_msg) => {
382 listener::msg_received(
383 wire_msg,
384 Participant::from_node(node_id),
385 None,
386 comm_events.clone(),
387 )
388 .await;
389 }
390 Err(error) => {
391 error!("Failed sending {msg_id:?} to {node_id:?}: {error:?}");
392 send_error(
393 node_id,
394 Error::InvalidMsgReceived(msg_id),
395 comm_events.clone(),
396 );
397 }
398 };
399 });
400}
401
402#[tracing::instrument(skip_all)]
403fn send_and_respond_on_stream(
404 msg_id: MsgId,
405 node_bytes: BTreeMap<NodeId, (Option<NodeLink>, UsrMsgBytes)>,
406 expected_targets: usize,
407 dst_stream: (Dst, SendStream),
408 comm_events: Sender<CommEvent>,
409) {
410 let _handle = task::spawn(async move {
411 let (dst, stream) = dst_stream;
412
413 let tasks = node_bytes
414 .into_iter()
415 .map(|pb| (pb, comm_events.clone()))
416 .map(|((node_id, (link, bytes)), comm_events)| async move {
417 let link = match link {
418 Some(link) => link,
419 None => return (node_id, Err(Error::ConnectingToUnknownNode(msg_id))),
420 };
421
422 let node_response_bytes =
423 match link.send_with_bi_return_response(bytes, msg_id).await {
424 Ok(response_bytes) => response_bytes,
425 Err(error) => {
426 error!("Failed sending {msg_id:?} to {node_id:?}: {error:?}");
427 send_error(node_id, Error::FailedSend(msg_id), comm_events);
428 return (node_id, Err(Error::FailedSend(msg_id)));
429 }
430 };
431
432 debug!("Response from node {node_id:?} is in for {msg_id:?}");
433 (node_id, Ok(node_response_bytes))
434 });
435
436 let node_results: Vec<(NodeId, Result<UsrMsgBytes>)> = join_all(tasks).await;
437
438 let succeeded: Vec<_> = node_results
439 .into_iter()
440 .filter_map(|(node_id, res)| match res {
441 Ok(bytes) => Some((node_id, bytes)),
442 Err(error) => {
443 error!("Failed sending {msg_id:?} to {node_id:?}: {error:?}");
444 send_error(node_id, Error::FailedSend(msg_id), comm_events.clone());
445 None
446 }
447 })
448 .collect();
449
450 let some_failed = expected_targets > succeeded.len();
451 let all_ok_equal = || succeeded.windows(2).all(|w| are_equal(&w[0].1, &w[1].1));
452
453 let response_bytes = if some_failed || !all_ok_equal() {
454 match error_response(dst) {
455 None => {
456 error!("Could not send the error response to client!");
457 return;
458 }
459 Some(bytes) => bytes,
460 }
461 } else {
462 match succeeded.last() {
463 Some((_, bytes)) => bytes.clone(),
464 _ => {
465 error!("Could not send the response to client!");
466 return;
467 }
468 }
469 };
470
471 send_on_stream(msg_id, response_bytes, stream).await;
472 });
473}
474
475#[tracing::instrument(skip_all)]
476fn send_error(node_id: NodeId, error: Error, comm_events: Sender<CommEvent>) {
477 let _handle = task::spawn(async move {
478 let error_msg =
479 format!("Failed to send error {error} of node {node_id} on comm event channel ");
480 if let Err(err) = comm_events.send(CommEvent::Error { node_id, error }).await {
481 error!("{error_msg} due to {err}.")
482 }
483 });
484}
485
486#[tracing::instrument(skip_all)]
487async fn send_on_stream(msg_id: MsgId, bytes: UsrMsgBytes, mut stream: SendStream) {
488 match stream.send_user_msg(bytes).await {
489 Ok(()) => trace!("Response to {msg_id:?} sent to client."),
490 Err(error) => error!("Could not send the response to {msg_id:?} to client due to {error}!"),
491 }
492}
493
494fn error_response(dst: Dst) -> Option<UsrMsgBytes> {
495 let kind = MsgKind::DataResponse(dst.name);
496 let response = DataResponse::NetworkIssue(MsgError::InconsistentStorageNodeResponses);
497 let payload = WireMsg::serialize_msg_payload(&response).ok()?;
498 let wire_msg = WireMsg::new_msg(MsgId::new(), payload, kind, dst);
499 wire_msg.serialize().ok()
500}
501
502#[tracing::instrument(skip_all)]
503fn are_equal(a: &UsrMsgBytes, b: &UsrMsgBytes) -> bool {
504 let (_, _, a_payload) = a;
505 let (_, _, b_payload) = b;
506 if !are_bytes_equal(a_payload.to_vec(), b_payload.to_vec()) {
507 return false;
508 }
509 true
510}
511
512#[tracing::instrument(skip_all)]
513fn are_bytes_equal(one: Vec<u8>, other: Vec<u8>) -> bool {
514 if one.len() != other.len() {
515 return false;
516 }
517 for (a, b) in one.into_iter().zip(other) {
518 if a != b {
519 return false;
520 }
521 }
522 true
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    use sn_interface::{
        messaging::{
            data::{ClientMsg, DataQuery},
            ClientAuth, Dst, MsgId, MsgKind,
        },
        types::{ChunkAddress, Keypair, NodeId},
    };

    use assert_matches::assert_matches;
    use eyre::{eyre, Result};
    use futures::future;
    use std::{net::Ipv4Addr, time::Duration};
    use tokio::{
        net::UdpSocket,
        sync::mpsc::{self, Receiver},
        time,
    };

    /// How long to wait for a message on an already-accepted connection.
    const TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn successful_send() -> Result<()> {
        let (comm, _rx) = Comm::new(local_addr(), None)?;

        let (node0, mut rx0) = new_node_id().await?;
        let (node1, mut rx1) = new_node_id().await?;

        comm.set_comm_targets([node0, node1].into());

        let node0_msg = new_test_msg(dst(node0))?;
        let node1_msg = new_test_msg(dst(node1))?;

        comm.send_out_bytes(node0, node0_msg.msg_id(), node0_msg.serialize()?);
        comm.send_out_bytes(node1, node1_msg.msg_id(), node1_msg.serialize()?);

        // A receiver closing without delivering must fail the test, not silently pass it.
        let bytes0 = rx0
            .recv()
            .await
            .ok_or_else(|| eyre!("node0 never received its message"))?;
        assert_eq!(WireMsg::from(bytes0)?, node0_msg);

        let bytes1 = rx1
            .recv()
            .await
            .ok_or_else(|| eyre!("node1 never received its message"))?;
        assert_eq!(WireMsg::from(bytes1)?, node1_msg);

        Ok(())
    }

    #[tokio::test]
    async fn failed_send() -> Result<()> {
        let (comm, mut rx) = Comm::new(local_addr(), None)?;

        let invalid_dst = get_invalid_node().await?;
        let invalid_addr = invalid_dst.addr();
        let msg = new_test_msg(dst(invalid_dst))?;

        // No link exists yet: the send must be rejected as targeting an unknown node.
        comm.send_out_bytes(invalid_dst, msg.msg_id(), msg.serialize()?);

        assert_matches!(rx.recv().await, Some(CommEvent::Error { node_id, error }) => {
            assert_matches!(error, Error::ConnectingToUnknownNode(_));
            assert_eq!(node_id.addr(), invalid_addr);
        });

        // With a link in place the send is attempted, but nothing on the other end answers.
        comm.set_comm_targets([invalid_dst].into());

        comm.send_out_bytes(invalid_dst, msg.msg_id(), msg.serialize()?);

        assert_matches!(rx.recv().await, Some(CommEvent::Error { node_id, error }) => {
            assert_matches!(error, Error::FailedSend(_));
            assert_eq!(node_id.addr(), invalid_addr);
        });

        Ok(())
    }

    #[tokio::test]
    async fn send_after_reconnect() -> Result<()> {
        let (send_comm, _rx) = Comm::new(local_addr(), None)?;

        let (recv_endpoint, mut incoming_connections) =
            Endpoint::builder().addr(local_addr()).server()?;
        let recv_addr = recv_endpoint.local_addr();
        let name = xor_name::rand::random();
        let node_id = NodeId::new(name, recv_addr);
        let msg0 = new_test_msg(dst(node_id))?;

        send_comm.set_comm_targets([node_id].into());

        send_comm.send_out_bytes(node_id, msg0.msg_id(), msg0.serialize()?);

        let mut msg0_received = false;

        {
            if let Some((_, mut incoming_msgs)) = incoming_connections.next().await {
                if let Some(msg) = time::timeout(TIMEOUT, incoming_msgs.next()).await?? {
                    assert_eq!(WireMsg::from(msg.0)?, msg0);
                    msg0_received = true;
                }
            }
            assert!(msg0_received);
        }

        let msg1 = new_test_msg(dst(node_id))?;
        send_comm.send_out_bytes(node_id, msg1.msg_id(), msg1.serialize()?);

        let mut msg1_received = false;

        if let Some((_, mut incoming_msgs)) = incoming_connections.next().await {
            if let Some(msg) = time::timeout(TIMEOUT, incoming_msgs.next()).await?? {
                assert_eq!(WireMsg::from(msg.0)?, msg1);
                msg1_received = true;
            }
        }

        assert!(msg1_received);

        Ok(())
    }

    #[tokio::test]
    async fn incoming_connection_lost() -> Result<()> {
        let (comm0, mut rx0) = Comm::new(local_addr(), None)?;
        let addr0 = comm0.socket_addr();

        let (comm1, _rx1) = Comm::new(local_addr(), None)?;

        let node_id = NodeId::new(xor_name::rand::random(), addr0);
        let msg = new_test_msg(dst(node_id))?;

        comm1.set_comm_targets([node_id].into());

        comm1.send_out_bytes(node_id, msg.msg_id(), msg.serialize()?);

        assert_matches!(rx0.recv().await, Some(CommEvent::Msg(MsgReceived { .. })));

        drop(comm1);

        assert_matches!(time::timeout(TIMEOUT, rx0.recv()).await, Err(_));

        Ok(())
    }

    fn dst(node_id: NodeId) -> Dst {
        Dst {
            name: node_id.name(),
            section_key: bls::SecretKey::random().public_key(),
        }
    }

    fn new_test_msg(dst: Dst) -> Result<WireMsg> {
        let src_keypair = Keypair::new_ed25519();

        let query = DataQuery::GetChunk(ChunkAddress(xor_name::rand::random()));
        let query = ClientMsg::Query(query);
        let payload = WireMsg::serialize_msg_payload(&query)?;

        let auth = ClientAuth {
            public_key: src_keypair.public_key(),
            signature: src_keypair.sign(&payload),
        };

        Ok(WireMsg::new_msg(
            MsgId::new(),
            payload,
            MsgKind::Client {
                auth,
                is_spend: false,
                query_index: None,
            },
            dst,
        ))
    }

    async fn new_node_id() -> Result<(NodeId, Receiver<UsrMsgBytes>)> {
        let (endpoint, mut incoming_connections) =
            Endpoint::builder().addr(local_addr()).server()?;
        let addr = endpoint.local_addr();

        let (tx, rx) = mpsc::channel(1);

        let _handle = tokio::task::spawn(async move {
            while let Some((_, mut incoming_messages)) = incoming_connections.next().await {
                while let Ok(Some(msg)) = incoming_messages.next().await {
                    let _ = tx.send(msg.0).await;
                }
            }
        });

        Ok((NodeId::new(xor_name::rand::random(), addr), rx))
    }

    /// Returns a node whose address is a bound UDP socket that never reads or replies, so
    /// connection attempts to it fail.
    async fn get_invalid_node() -> Result<NodeId> {
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = socket.local_addr()?;

        let _handle = tokio::task::spawn(async move {
            // Bind by name (a `let _ =` wildcard does not move) so the task owns the socket
            // and keeps the address bound for as long as it runs.
            let _socket = socket;
            debug!("get invalid participant");
            future::pending::<()>().await;
        });

        Ok(NodeId::new(xor_name::rand::random(), addr))
    }

    fn local_addr() -> SocketAddr {
        (Ipv4Addr::LOCALHOST, 0).into()
    }
}