use self::outbound::OutboundMessageRequester;
use crate::{
actor::{DhtActor, DhtRequest, DhtRequester},
connectivity::DhtConnectivity,
discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester, DhtDiscoveryService},
inbound,
inbound::{DecryptedDhtMessage, DhtInboundMessage},
logging_middleware::MessageLoggingLayer,
outbound,
outbound::DhtOutboundRequest,
proto::envelope::DhtMessageType,
storage::{DbConnection, StorageError},
store_forward,
store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService},
tower_filter,
DedupLayer,
DhtActorError,
DhtConfig,
};
use futures::{channel::mpsc, future, Future};
use log::*;
use std::sync::Arc;
use tari_comms::{
connectivity::ConnectivityRequester,
message::{InboundMessage, OutboundMessage},
peer_manager::{NodeIdentity, PeerFeatures, PeerManager},
pipeline::PipelineError,
};
use tari_shutdown::ShutdownSignal;
use thiserror::Error;
use tower::{layer::Layer, Service, ServiceBuilder};
/// Log target for this module.
const LOG_TARGET: &str = "comms::dht";

// Buffer sizes handed to `mpsc::channel` when `Dht::initialize` creates the request channel for
// each background service.

/// Buffer size of the channel carrying `DhtRequest`s to the `DhtActor`.
const DHT_ACTOR_CHANNEL_SIZE: usize = 100;
/// Buffer size of the channel carrying `DhtDiscoveryRequest`s to the `DhtDiscoveryService`.
const DHT_DISCOVERY_CHANNEL_SIZE: usize = 100;
/// Buffer size of the channel carrying `StoreAndForwardRequest`s to the `StoreAndForwardService`.
const DHT_SAF_SERVICE_CHANNEL_SIZE: usize = 100;
/// Errors returned while initializing the DHT.
#[derive(Debug, Error)]
pub enum DhtInitializationError {
    /// Returned by `Dht::initialize` when `DbConnection::connect_and_migrate` fails. Any
    /// `StorageError` also converts into this variant via the derived `From` impl.
    #[error("Database initialization failed: {0}")]
    DatabaseMigrationFailed(#[from] StorageError),
    /// Wraps a `StoreAndForwardError` (convertible via the derived `From` impl).
    #[error("StoreAndForwardInitializationError: {0}")]
    StoreAndForwardInitializationError(#[from] StoreAndForwardError),
    /// Wraps a `DhtActorError` (convertible via the derived `From` impl).
    #[error("DhtActorInitializationError: {0}")]
    DhtActorInitializationError(#[from] DhtActorError),
}
/// Entry point to the DHT.
///
/// Holds the configuration, shared node state and the sender halves of the request channels for
/// the background services. From these it builds requester handles and inbound/outbound
/// middleware layers.
pub struct Dht {
    /// This node's identity; shared with every service and middleware layer.
    node_identity: Arc<NodeIdentity>,
    /// Peer manager shared with the services and middleware layers.
    peer_manager: Arc<PeerManager>,
    /// DHT configuration; cloned into each service and layer that needs it.
    config: DhtConfig,
    /// Outbound request sender, wrapped by `outbound_requester`.
    outbound_tx: mpsc::Sender<DhtOutboundRequest>,
    /// Sender half of the `DhtActor` request channel, wrapped by `dht_requester`.
    dht_sender: mpsc::Sender<DhtRequest>,
    /// Sender half of the store-and-forward service request channel, wrapped by
    /// `store_and_forward_requester`.
    saf_sender: mpsc::Sender<StoreAndForwardRequest>,
    /// Sender half of the discovery service request channel, wrapped by
    /// `discovery_service_requester`.
    discovery_sender: mpsc::Sender<DhtDiscoveryRequest>,
    /// Connectivity requester from `tari_comms`; cloned into the actor, the DHT connectivity
    /// service and the store-and-forward service.
    connectivity: ConnectivityRequester,
}
impl Dht {
    /// Creates a `Dht` and starts its background services.
    ///
    /// Connects to and migrates the database at `config.database_url`, then spawns, in order:
    /// the DHT connectivity service, the store-and-forward service, the `DhtActor` and the
    /// discovery service. Each is given a clone of `shutdown_signal`.
    ///
    /// # Errors
    ///
    /// Returns `DhtInitializationError::DatabaseMigrationFailed` if connecting to or migrating the
    /// database fails. No services are spawned in that case.
    pub async fn initialize(
        config: DhtConfig,
        node_identity: Arc<NodeIdentity>,
        peer_manager: Arc<PeerManager>,
        outbound_tx: mpsc::Sender<DhtOutboundRequest>,
        connectivity: ConnectivityRequester,
        shutdown_signal: ShutdownSignal,
    ) -> Result<Self, DhtInitializationError>
    {
        // The `Dht` keeps the sender halves; each receiver is handed to the service that owns it.
        let (dht_sender, dht_receiver) = mpsc::channel(DHT_ACTOR_CHANNEL_SIZE);
        let (discovery_sender, discovery_receiver) = mpsc::channel(DHT_DISCOVERY_CHANNEL_SIZE);
        let (saf_sender, saf_receiver) = mpsc::channel(DHT_SAF_SERVICE_CHANNEL_SIZE);

        let dht = Self {
            node_identity,
            peer_manager,
            config,
            outbound_tx,
            dht_sender,
            saf_sender,
            connectivity,
            discovery_sender,
        };

        let conn = DbConnection::connect_and_migrate(dht.config.database_url.clone())
            .await
            .map_err(DhtInitializationError::DatabaseMigrationFailed)?;

        dht.connectivity_service(shutdown_signal.clone()).spawn();
        dht.store_and_forward_service(conn.clone(), saf_receiver, shutdown_signal.clone())
            .spawn();
        dht.actor(conn, dht_receiver, shutdown_signal.clone()).spawn();
        dht.discovery_service(discovery_receiver, shutdown_signal).spawn();

        debug!(target: LOG_TARGET, "Dht initialization complete.");

        Ok(dht)
    }

    /// Builds (but does not spawn) the `DhtActor` that serves `request_receiver`.
    fn actor(
        &self,
        conn: DbConnection,
        request_receiver: mpsc::Receiver<DhtRequest>,
        shutdown_signal: ShutdownSignal,
    ) -> DhtActor
    {
        DhtActor::new(
            self.config.clone(),
            conn,
            Arc::clone(&self.node_identity),
            Arc::clone(&self.peer_manager),
            self.connectivity.clone(),
            self.outbound_requester(),
            request_receiver,
            shutdown_signal,
        )
    }

    /// Builds (but does not spawn) the `DhtDiscoveryService` that serves `request_receiver`.
    fn discovery_service(
        &self,
        request_receiver: mpsc::Receiver<DhtDiscoveryRequest>,
        shutdown_signal: ShutdownSignal,
    ) -> DhtDiscoveryService
    {
        DhtDiscoveryService::new(
            self.config.clone(),
            Arc::clone(&self.node_identity),
            Arc::clone(&self.peer_manager),
            self.outbound_requester(),
            request_receiver,
            shutdown_signal,
        )
    }

    /// Builds (but does not spawn) the `DhtConnectivity` service.
    fn connectivity_service(&self, shutdown_signal: ShutdownSignal) -> DhtConnectivity {
        DhtConnectivity::new(
            self.config.clone(),
            Arc::clone(&self.peer_manager),
            Arc::clone(&self.node_identity),
            self.connectivity.clone(),
            self.dht_requester(),
            shutdown_signal,
        )
    }

    /// Builds (but does not spawn) the `StoreAndForwardService` that serves `request_rx`.
    fn store_and_forward_service(
        &self,
        conn: DbConnection,
        request_rx: mpsc::Receiver<StoreAndForwardRequest>,
        shutdown_signal: ShutdownSignal,
    ) -> StoreAndForwardService
    {
        StoreAndForwardService::new(
            self.config.clone(),
            conn,
            Arc::clone(&self.node_identity),
            Arc::clone(&self.peer_manager),
            self.dht_requester(),
            self.connectivity.clone(),
            self.outbound_requester(),
            request_rx,
            shutdown_signal,
        )
    }

    /// Returns a new handle for sending outbound DHT requests.
    pub fn outbound_requester(&self) -> OutboundMessageRequester {
        OutboundMessageRequester::new(self.outbound_tx.clone())
    }

    /// Returns a new handle for sending requests to the `DhtActor`.
    pub fn dht_requester(&self) -> DhtRequester {
        DhtRequester::new(self.dht_sender.clone())
    }

    /// Returns a new handle for sending requests to the discovery service. The handle uses
    /// `config.discovery_request_timeout`.
    pub fn discovery_service_requester(&self) -> DhtDiscoveryRequester {
        DhtDiscoveryRequester::new(self.discovery_sender.clone(), self.config.discovery_request_timeout)
    }

    /// Returns a new handle for sending requests to the store-and-forward service.
    pub fn store_and_forward_requester(&self) -> StoreAndForwardRequester {
        StoreAndForwardRequester::new(self.saf_sender.clone())
    }

    /// Builds the inbound middleware stack, with `S` as the innermost service.
    ///
    /// `ServiceBuilder` makes the first layer added the outermost one, so an inbound message passes
    /// through the layers in this order:
    /// 1. deserialize
    /// 2. validate (network)
    /// 3. dedup
    /// 4. unsupported-SAF filter
    /// 5. logging
    /// 6. decryption
    /// 7. SAF forward
    /// 8. SAF store
    /// 9. SAF message handler
    /// 10. DHT handler
    ///
    /// After the last layer, the message reaches `S`.
    pub fn inbound_middleware_layer<S>(
        &self,
    ) -> impl Layer<
        S,
        Service = impl Service<
            InboundMessage,
            Response = (),
            Error = PipelineError,
            Future = impl Future<Output = Result<(), PipelineError>> + Send,
        > + Clone
                      + Send,
    >
    where
        S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + Sync + 'static,
        S::Future: Send,
    {
        ServiceBuilder::new()
            .layer(inbound::DeserializeLayer::new(Arc::clone(&self.peer_manager)))
            .layer(inbound::ValidateLayer::new(self.config.network))
            .layer(DedupLayer::new(self.dht_requester()))
            .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
            .layer(MessageLoggingLayer::new(format!(
                "Inbound [{}]",
                self.node_identity.node_id().short_str()
            )))
            .layer(inbound::DecryptionLayer::new(Arc::clone(&self.node_identity)))
            .layer(store_forward::ForwardLayer::new(
                self.config.clone(),
                self.outbound_requester(),
                self.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD),
            ))
            .layer(store_forward::StoreLayer::new(
                self.config.clone(),
                Arc::clone(&self.peer_manager),
                Arc::clone(&self.node_identity),
                self.store_and_forward_requester(),
            ))
            .layer(store_forward::MessageHandlerLayer::new(
                self.config.clone(),
                self.store_and_forward_requester(),
                self.dht_requester(),
                Arc::clone(&self.node_identity),
                Arc::clone(&self.peer_manager),
                self.outbound_requester(),
            ))
            .layer(inbound::DhtHandlerLayer::new(
                self.config.clone(),
                Arc::clone(&self.node_identity),
                Arc::clone(&self.peer_manager),
                self.discovery_service_requester(),
                self.outbound_requester(),
            ))
            .into_inner()
    }

    /// Builds the outbound middleware stack, with `S` as the innermost service.
    ///
    /// An outbound request passes through these layers in order: broadcast, logging, serialize.
    /// After serialization it reaches `S`.
    pub fn outbound_middleware_layer<S>(
        &self,
    ) -> impl Layer<
        S,
        Service = impl Service<
            DhtOutboundRequest,
            Response = (),
            Error = PipelineError,
            Future = impl Future<Output = Result<(), PipelineError>> + Send,
        > + Clone
                      + Send,
    >
    where
        S: Service<OutboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
        S::Future: Send,
    {
        ServiceBuilder::new()
            .layer(outbound::BroadcastLayer::new(
                Arc::clone(&self.node_identity),
                self.dht_requester(),
                self.discovery_service_requester(),
                self.config.network,
            ))
            .layer(MessageLoggingLayer::new(format!(
                "Outbound [{}]",
                self.node_identity.node_id().short_str()
            )))
            .layer(outbound::SerializeLayer)
            .into_inner()
    }

    /// Builds a filter predicate that rejects store-and-forward request messages when this node
    /// does not advertise `PeerFeatures::DHT_STORE_FORWARD`.
    ///
    /// - If the feature is present, every message passes.
    /// - Otherwise, only `DhtMessageType::SafRequestMessages` is rejected with a `PipelineError`.
    ///   All other types, including `SafStoredMessages`, pass. Presumably this lets nodes without
    ///   the feature still receive SAF responses.
    fn unsupported_saf_messages_filter(
        &self,
    ) -> impl tower_filter::Predicate<DhtInboundMessage, Future = future::Ready<Result<(), PipelineError>>> + Clone + Send
    {
        let node_identity = Arc::clone(&self.node_identity);
        move |msg: &DhtInboundMessage| {
            if node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) {
                return future::ready(Ok(()));
            }

            match msg.dht_header.message_type {
                DhtMessageType::SafRequestMessages => {
                    // Use the module's log target, like the rest of this file, so the message is
                    // filterable under "comms::dht".
                    debug!(
                        target: LOG_TARGET,
                        "Received store and forward message from PublicKey={}. Store and forward feature is not \
                         supported by this node. Discarding message.",
                        msg.source_peer.public_key
                    );
                    future::ready(Err(PipelineError::from_debug(
                        "Message filtered out because store and forward is not supported by this node",
                    )))
                },
                _ => future::ready(Ok(())),
            }
        }
    }
}
#[cfg(test)]
mod test {
    use crate::{
        crypt,
        envelope::DhtMessageFlags,
        outbound::mock::create_outbound_service_mock,
        proto::envelope::DhtMessageType,
        test_utils::{
            make_client_identity,
            make_comms_inbound_message,
            make_dht_envelope,
            make_node_identity,
            make_peer_manager,
        },
        DhtBuilder,
    };
    use futures::{channel::mpsc, StreamExt};
    use std::{sync::Arc, time::Duration};
    use tari_comms::{
        message::{MessageExt, MessageTag},
        pipeline::SinkService,
        test_utils::mocks::create_connectivity_mock,
        wrap_in_envelope_body,
    };
    use tari_shutdown::Shutdown;
    use tokio::{task, time};
    use tower::{layer::Layer, Service};

    /// An unencrypted message addressed to this node reaches the innermost service intact.
    #[tokio_macros::test_basic]
    async fn stack_unencrypted() {
        let node_identity = make_node_identity();
        let peer_manager = make_peer_manager();
        let (connectivity, _) = create_connectivity_mock();
        peer_manager.add_peer(node_identity.to_peer()).await.unwrap();
        let (out_tx, _) = mpsc::channel(10);
        let shutdown = Shutdown::new();
        let dht = DhtBuilder::new(
            Arc::clone(&node_identity),
            peer_manager,
            out_tx,
            connectivity,
            shutdown.to_signal(),
        )
        .local_test()
        .finish()
        .await
        .unwrap();

        let (out_tx, mut out_rx) = mpsc::channel(10);
        let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx));

        let msg = wrap_in_envelope_body!(b"secret".to_vec());
        let dht_envelope = make_dht_envelope(
            &node_identity,
            msg.to_encoded_bytes(),
            DhtMessageFlags::empty(),
            false,
            MessageTag::new(),
        );
        let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

        let msg = {
            service.call(inbound_message).await.unwrap();
            let msg = time::timeout(Duration::from_secs(10), out_rx.next())
                .await
                .unwrap()
                .unwrap();
            msg.success().unwrap().decode_part::<Vec<u8>>(0).unwrap().unwrap()
        };

        assert_eq!(msg, b"secret");
    }

    /// A message encrypted for this node is decrypted and reaches the innermost service.
    #[tokio_macros::test_basic]
    async fn stack_encrypted() {
        let node_identity = make_node_identity();
        let peer_manager = make_peer_manager();
        let (connectivity, _) = create_connectivity_mock();
        peer_manager.add_peer(node_identity.to_peer()).await.unwrap();
        let (out_tx, _out_rx) = mpsc::channel(10);
        let shutdown = Shutdown::new();
        let dht = DhtBuilder::new(
            Arc::clone(&node_identity),
            peer_manager,
            out_tx,
            connectivity,
            shutdown.to_signal(),
        )
        .finish()
        .await
        .unwrap();

        let (out_tx, mut out_rx) = mpsc::channel(10);
        let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx));

        let msg = wrap_in_envelope_body!(b"secret".to_vec());
        let dht_envelope = make_dht_envelope(
            &node_identity,
            msg.to_encoded_bytes(),
            DhtMessageFlags::ENCRYPTED,
            true,
            MessageTag::new(),
        );
        let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

        let msg = {
            service.call(inbound_message).await.unwrap();
            let msg = time::timeout(Duration::from_secs(10), out_rx.next())
                .await
                .unwrap()
                .unwrap();
            msg.success().unwrap().decode_part::<Vec<u8>>(0).unwrap().unwrap()
        };

        assert_eq!(msg, b"secret");
    }

    /// A message this node cannot decrypt is forwarded (origin MAC preserved), not passed inward.
    #[tokio_macros::test_basic]
    async fn stack_forward() {
        let node_identity = make_node_identity();
        let peer_manager = make_peer_manager();
        let shutdown = Shutdown::new();
        peer_manager.add_peer(node_identity.to_peer()).await.unwrap();
        let (connectivity, _) = create_connectivity_mock();

        let (next_service_tx, mut next_service_rx) = mpsc::channel(10);
        let (oms_requester, oms_mock) = create_outbound_service_mock(1);

        let dht = DhtBuilder::new(
            Arc::clone(&node_identity),
            peer_manager,
            oms_requester.get_mpsc_sender(),
            connectivity,
            shutdown.to_signal(),
        )
        .finish()
        .await
        .unwrap();

        let oms_mock_state = oms_mock.get_state();
        task::spawn(oms_mock.run());

        let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx));

        let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec());

        // Encrypt for a different identity so this node cannot decrypt the body.
        let node_identity2 = make_node_identity();
        let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key());
        let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes()).unwrap();
        let dht_envelope = make_dht_envelope(
            &node_identity,
            encrypted_bytes,
            DhtMessageFlags::ENCRYPTED,
            true,
            MessageTag::new(),
        );

        let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone();
        assert!(!origin_mac.is_empty());
        let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());
        service.call(inbound_message).await.unwrap();
        assert_eq!(oms_mock_state.call_count(), 1);
        let (params, _) = oms_mock_state.pop_call().unwrap();

        // The forwarded header must carry the original origin MAC unchanged.
        assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac);

        // Nothing should have been passed to the next service.
        assert!(next_service_rx.try_next().is_err());
    }

    /// A node without the DHT_STORE_FORWARD feature rejects SAF request messages at the filter.
    #[tokio_macros::test_basic]
    async fn stack_filter_saf_message() {
        let node_identity = make_client_identity();
        let peer_manager = make_peer_manager();
        let (connectivity, _) = create_connectivity_mock();
        peer_manager.add_peer(node_identity.to_peer()).await.unwrap();

        let (out_tx, _) = mpsc::channel(10);
        let shutdown = Shutdown::new();
        let dht = DhtBuilder::new(
            Arc::clone(&node_identity),
            peer_manager,
            out_tx,
            connectivity,
            shutdown.to_signal(),
        )
        .finish()
        .await
        .unwrap();

        let (next_service_tx, mut next_service_rx) = mpsc::channel(10);

        let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx));

        let msg = wrap_in_envelope_body!(b"secret".to_vec());
        let mut dht_envelope = make_dht_envelope(
            &node_identity,
            msg.to_encoded_bytes(),
            DhtMessageFlags::empty(),
            false,
            MessageTag::new(),
        );
        // `SafRequestMessages` is the only type `unsupported_saf_messages_filter` rejects.
        // `SafStoredMessages` would pass the filter and fail (if at all) in a later layer, so it
        // would not exercise the filter.
        if let Some(header) = dht_envelope.header.as_mut() {
            header.message_type = DhtMessageType::SafRequestMessages as i32;
        }
        let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

        service.call(inbound_message).await.unwrap_err();
        // Checks that the still-open channel is empty without blocking the test.
        assert_eq!(
            format!("{}", next_service_rx.try_next().unwrap_err()),
            "receiver channel is empty"
        );
    }
}