use crate::{
broadcast_strategy::BroadcastStrategy,
discovery::DhtDiscoveryError,
outbound::{DhtOutboundError, OutboundMessageRequester, SendMessageParams},
proto::{dht::JoinMessage, envelope::DhtMessageType},
storage::{DbConnection, DhtDatabase, DhtMetadataKey, StorageError},
DhtConfig,
};
use chrono::{DateTime, Utc};
use futures::{
channel::{mpsc, mpsc::SendError, oneshot},
future,
future::BoxFuture,
stream::{Fuse, FuturesUnordered},
SinkExt,
StreamExt,
};
use log::*;
use std::{cmp, fmt, fmt::Display, sync::Arc};
use tari_comms::{
connection_manager::ConnectionManagerError,
connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection},
peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy},
};
use tari_shutdown::ShutdownSignal;
use tari_utilities::message_format::{MessageFormat, MessageFormatError};
use thiserror::Error;
use tokio::task;
use ttl_cache::TtlCache;
/// Log target used by all log output from this module
const LOG_TARGET: &str = "comms::dht::actor";
/// Errors that can be emitted by the `DhtActor` and returned to `DhtRequester` callers.
#[derive(Debug, Error)]
pub enum DhtActorError {
    #[error("MPSC channel is disconnected")]
    ChannelDisconnected,
    #[error("MPSC sender was unable to send because the channel buffer is full")]
    SendBufferFull,
    #[error("Reply sender canceled the request")]
    ReplyCanceled,
    #[error("PeerManagerError: {0}")]
    PeerManagerError(#[from] PeerManagerError),
    #[error("Failed to broadcast join message: {0}")]
    FailedToBroadcastJoinMessage(DhtOutboundError),
    #[error("DiscoveryError: {0}")]
    DiscoveryError(#[from] DhtDiscoveryError),
    #[error("StorageError: {0}")]
    StorageError(#[from] StorageError),
    #[error("StoredValueFailedToDeserialize: {0}")]
    StoredValueFailedToDeserialize(MessageFormatError),
    #[error("FailedToSerializeValue: {0}")]
    FailedToSerializeValue(MessageFormatError),
    #[error("ConnectionManagerError: {0}")]
    ConnectionManagerError(#[from] ConnectionManagerError),
    #[error("ConnectivityError: {0}")]
    ConnectivityError(#[from] ConnectivityError),
    #[error("Connectivity event stream closed")]
    ConnectivityEventStreamClosed,
}
impl From<SendError> for DhtActorError {
    /// Maps an mpsc `SendError` onto the matching actor error variant.
    ///
    /// A `SendError` is always either "disconnected" or "full", so any other
    /// combination is a broken invariant.
    fn from(err: SendError) -> Self {
        match (err.is_disconnected(), err.is_full()) {
            (true, _) => DhtActorError::ChannelDisconnected,
            (_, true) => DhtActorError::SendBufferFull,
            _ => unreachable!(),
        }
    }
}
/// Requests serviced by the `DhtActor`; sent via a `DhtRequester`.
#[derive(Debug)]
pub enum DhtRequest {
    /// Broadcast a join message for this node to the network
    SendJoin,
    /// Insert a message hash into the dedup cache; replies `true` if the hash was already present
    MsgHashCacheInsert(Vec<u8>, oneshot::Sender<bool>),
    /// Select peers according to the given broadcast strategy; replies with the selected NodeIds
    SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<NodeId>>),
    /// Fetch the raw bytes stored for a metadata key (`None` if the key is unset)
    GetMetadata(DhtMetadataKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>),
    /// Store raw bytes under a metadata key
    SetMetadata(DhtMetadataKey, Vec<u8>, oneshot::Sender<Result<(), DhtActorError>>),
}
impl Display for DhtRequest {
    /// Human-readable request description used in log output.
    ///
    /// Uses `write!` to format directly into the `Formatter` instead of the
    /// original `f.write_str(&format!(...))`, which allocated an intermediate
    /// `String` on every call. Output text is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use DhtRequest::*;
        match self {
            SendJoin => f.write_str("SendJoin"),
            MsgHashCacheInsert(_, _) => f.write_str("MsgHashCacheInsert"),
            SelectPeers(s, _) => write!(f, "SelectPeers (Strategy={})", s),
            GetMetadata(key, _) => write!(f, "GetMetadata (key={})", key),
            SetMetadata(key, value, _) => write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len()),
        }
    }
}
/// Cloneable client handle for sending `DhtRequest`s to the `DhtActor`.
#[derive(Clone)]
pub struct DhtRequester {
    // Sender side of the actor's request channel
    sender: mpsc::Sender<DhtRequest>,
}
impl DhtRequester {
    /// Wraps the sender side of the DHT actor's request channel.
    pub fn new(sender: mpsc::Sender<DhtRequest>) -> Self {
        Self { sender }
    }

    /// Asks the actor to broadcast a join message for this node.
    pub async fn send_join(&mut self) -> Result<(), DhtActorError> {
        self.sender.send(DhtRequest::SendJoin).await.map_err(Into::into)
    }

    /// Requests a peer selection for the given broadcast strategy and awaits the reply.
    pub async fn select_peers(&mut self, broadcast_strategy: BroadcastStrategy) -> Result<Vec<NodeId>, DhtActorError> {
        let (tx, rx) = oneshot::channel();
        let request = DhtRequest::SelectPeers(broadcast_strategy, tx);
        self.sender.send(request).await?;
        rx.await.map_err(|_| DhtActorError::ReplyCanceled)
    }

    /// Inserts a message hash into the actor's dedup cache.
    /// Resolves to `true` if the hash was already present.
    pub async fn insert_message_hash(&mut self, signature: Vec<u8>) -> Result<bool, DhtActorError> {
        let (tx, rx) = oneshot::channel();
        let request = DhtRequest::MsgHashCacheInsert(signature, tx);
        self.sender.send(request).await?;
        rx.await.map_err(|_| DhtActorError::ReplyCanceled)
    }

    /// Fetches the value stored under `key` and deserializes it, or returns `None`
    /// when the key has no value.
    pub async fn get_metadata<T: MessageFormat>(&mut self, key: DhtMetadataKey) -> Result<Option<T>, DhtActorError> {
        let (tx, rx) = oneshot::channel();
        self.sender.send(DhtRequest::GetMetadata(key, tx)).await?;
        let maybe_bytes = rx.await.map_err(|_| DhtActorError::ReplyCanceled)??;
        maybe_bytes
            .map(|bytes| T::from_binary(&bytes).map_err(DhtActorError::StoredValueFailedToDeserialize))
            .transpose()
    }

    /// Serializes `value` and stores it under `key`.
    pub async fn set_metadata<T: MessageFormat>(&mut self, key: DhtMetadataKey, value: T) -> Result<(), DhtActorError> {
        let bytes = value.to_binary().map_err(DhtActorError::FailedToSerializeValue)?;
        let (tx, rx) = oneshot::channel();
        self.sender.send(DhtRequest::SetMetadata(key, bytes, tx)).await?;
        rx.await.map_err(|_| DhtActorError::ReplyCanceled)?
    }
}
/// Actor that services `DhtRequest`s: peer selection, join broadcasts,
/// message-hash deduplication and DHT metadata storage.
pub struct DhtActor {
    node_identity: Arc<NodeIdentity>,
    peer_manager: Arc<PeerManager>,
    database: DhtDatabase,
    outbound_requester: OutboundMessageRequester,
    connectivity: ConnectivityRequester,
    config: DhtConfig,
    // Taken (set to None) once `run` starts
    shutdown_signal: Option<ShutdownSignal>,
    // Fused so it can be polled repeatedly inside `futures::select!`
    request_rx: Fuse<mpsc::Receiver<DhtRequest>>,
    // TTL cache used to detect duplicate message hashes
    msg_hash_cache: TtlCache<Vec<u8>, ()>,
}
impl DhtActor {
    /// Creates a new `DhtActor`.
    ///
    /// The actor does nothing until [`spawn`](Self::spawn) is called.
    /// `request_rx` is the channel on which `DhtRequester` clients submit requests.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: DhtConfig,
        conn: DbConnection,
        node_identity: Arc<NodeIdentity>,
        peer_manager: Arc<PeerManager>,
        connectivity: ConnectivityRequester,
        outbound_requester: OutboundMessageRequester,
        request_rx: mpsc::Receiver<DhtRequest>,
        shutdown_signal: ShutdownSignal,
    ) -> Self
    {
        Self {
            // Dedup cache capacity comes from config (TTL is applied per-insert)
            msg_hash_cache: TtlCache::new(config.msg_hash_cache_capacity),
            config,
            database: DhtDatabase::new(conn),
            outbound_requester,
            peer_manager,
            connectivity,
            node_identity,
            shutdown_signal: Some(shutdown_signal),
            request_rx: request_rx.fuse(),
        }
    }

    /// Consumes the actor and runs its event loop on a new tokio task.
    pub fn spawn(self) {
        task::spawn(async move {
            if let Err(err) = self.run().await {
                error!(target: LOG_TARGET, "DhtActor failed to start with error: {:?}", err);
            }
        });
    }

    /// Main event loop: handles incoming requests concurrently until the
    /// shutdown signal fires.
    async fn run(mut self) -> Result<(), DhtActorError> {
        // Best-effort read of the previously recorded offline timestamp; any
        // storage error is discarded (`ok().flatten()`)
        let offline_ts = self
            .database
            .get_metadata_value::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp)
            .await
            .ok()
            .flatten();
        info!(
            target: LOG_TARGET,
            "DhtActor started. {}",
            offline_ts
                .map(|dt| format!("Dht has been offline since '{}'", dt))
                .unwrap_or_else(String::new)
        );
        // In-flight request handler futures, polled concurrently with new requests
        let mut pending_jobs = FuturesUnordered::new();
        let mut shutdown_signal = self
            .shutdown_signal
            .take()
            .expect("DhtActor initialized without shutdown_signal");
        loop {
            futures::select! {
                request = self.request_rx.select_next_some() => {
                    trace!(target: LOG_TARGET, "DhtActor received request: {}", request);
                    pending_jobs.push(self.request_handler(request));
                },
                result = pending_jobs.select_next_some() => {
                    // Handler errors are logged, never fatal to the actor
                    if let Err(err) = result {
                        debug!(target: LOG_TARGET, "Error when handling DHT request message. {}", err);
                    }
                },
                _ = shutdown_signal => {
                    info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal.");
                    // Record the shutdown time so the next start can report how long we were offline
                    self.mark_shutdown_time().await;
                    break Ok(());
                },
            }
        }
    }

    /// Persists the current time as the offline timestamp. Failures are logged
    /// and swallowed (best-effort during shutdown).
    async fn mark_shutdown_time(&self) {
        if let Err(err) = self
            .database
            .set_metadata_value(DhtMetadataKey::OfflineTimestamp, Utc::now())
            .await
        {
            warn!(target: LOG_TARGET, "Failed to mark offline time: {:?}", err);
        }
    }

    /// Converts a request into a boxed `'static` future that performs the work.
    /// Everything the future needs is cloned out of `self` so it can run
    /// concurrently with the event loop.
    fn request_handler(&mut self, request: DhtRequest) -> BoxFuture<'static, Result<(), DhtActorError>> {
        use DhtRequest::*;
        match request {
            SendJoin => {
                let node_identity = Arc::clone(&self.node_identity);
                let outbound_requester = self.outbound_requester.clone();
                Box::pin(Self::broadcast_join(node_identity, outbound_requester))
            },
            MsgHashCacheInsert(hash, reply_tx) => {
                // `insert` returns the previous value, so `Some(_)` means the
                // hash was already in the cache (i.e. a duplicate message)
                let already_exists = self
                    .msg_hash_cache
                    .insert(hash, (), self.config.msg_hash_cache_ttl)
                    .is_some();
                let result = reply_tx.send(already_exists).map_err(|_| DhtActorError::ReplyCanceled);
                Box::pin(future::ready(result))
            },
            SelectPeers(broadcast_strategy, reply_tx) => {
                let peer_manager = Arc::clone(&self.peer_manager);
                let node_identity = Arc::clone(&self.node_identity);
                let connectivity = self.connectivity.clone();
                let config = self.config.clone();
                Box::pin(async move {
                    match Self::select_peers(config, node_identity, peer_manager, connectivity, broadcast_strategy)
                        .await
                    {
                        Ok(peers) => reply_tx.send(peers).map_err(|_| DhtActorError::ReplyCanceled),
                        Err(err) => {
                            // On failure the caller receives an empty selection
                            // rather than an error
                            warn!(target: LOG_TARGET, "Peer selection failed: {:?}", err);
                            reply_tx.send(Vec::new()).map_err(|_| DhtActorError::ReplyCanceled)
                        },
                    }
                })
            },
            GetMetadata(key, reply_tx) => {
                let db = self.database.clone();
                Box::pin(async move {
                    // A dropped receiver is ignored — the caller has gone away
                    let _ = reply_tx.send(db.get_metadata_value_bytes(key).await.map_err(Into::into));
                    Ok(())
                })
            },
            SetMetadata(key, value, reply_tx) => {
                let db = self.database.clone();
                Box::pin(async move {
                    match db.set_metadata_value_bytes(key, value).await {
                        Ok(_) => {
                            debug!(target: LOG_TARGET, "Dht metadata '{}' set", key);
                            let _ = reply_tx.send(Ok(()));
                        },
                        Err(err) => {
                            warn!(target: LOG_TARGET, "Unable to set metadata because {:?}", err);
                            let _ = reply_tx.send(Err(err.into()));
                        },
                    }
                    Ok(())
                })
            },
        }
    }

    /// Sends a Join message (built from this node's identity) addressed to the
    /// peers closest to this node's NodeId, with origin forced on.
    async fn broadcast_join(
        node_identity: Arc<NodeIdentity>,
        mut outbound_requester: OutboundMessageRequester,
    ) -> Result<(), DhtActorError>
    {
        let message = JoinMessage::from(&node_identity);
        debug!(target: LOG_TARGET, "Sending Join message to closest peers");
        outbound_requester
            .send_message_no_header(
                SendMessageParams::new()
                    .closest(node_identity.node_id().clone(), vec![])
                    .with_destination(node_identity.node_id().clone().into())
                    .with_dht_message_type(DhtMessageType::Join)
                    .force_origin()
                    .finish(),
                message,
            )
            .await
            .map_err(DhtActorError::FailedToBroadcastJoinMessage)?;
        Ok(())
    }

    /// Resolves a `BroadcastStrategy` into a concrete list of peer `NodeId`s.
    ///
    /// Static so it can be cloned into a `'static` handler future.
    async fn select_peers(
        config: DhtConfig,
        node_identity: Arc<NodeIdentity>,
        peer_manager: Arc<PeerManager>,
        mut connectivity: ConnectivityRequester,
        broadcast_strategy: BroadcastStrategy,
    ) -> Result<Vec<NodeId>, DhtActorError>
    {
        use BroadcastStrategy::*;
        match broadcast_strategy {
            DirectNodeId(node_id) => {
                // An unknown peer yields an empty Vec, not an error
                peer_manager
                    .direct_identity_node_id(&node_id)
                    .await
                    .map(|peer| peer.map(|p| vec![p.node_id]).unwrap_or_default())
                    .map_err(Into::into)
            },
            DirectPublicKey(public_key) => {
                peer_manager
                    .direct_identity_public_key(&public_key)
                    .await
                    .map(|peer| peer.map(|p| vec![p.node_id]).unwrap_or_default())
                    .map_err(Into::into)
            },
            Flood(exclude) => {
                // All currently connected nodes, minus the exclusions
                let peers = connectivity
                    .select_connections(ConnectivitySelection::all_nodes(exclude))
                    .await?;
                Ok(peers.into_iter().map(|p| p.peer_node_id().clone()).collect())
            },
            Closest(closest_request) => {
                // Start with connected peers closest to the target
                let connections = connectivity
                    .select_connections(ConnectivitySelection::closest_to(
                        closest_request.node_id.clone(),
                        config.broadcast_factor,
                        closest_request.excluded_peers.clone(),
                    ))
                    .await?;
                let mut candidates = connections
                    .iter()
                    .map(|conn| conn.peer_node_id())
                    .cloned()
                    .collect::<Vec<_>>();
                if !closest_request.connected_only {
                    // Top up from the peer list, excluding both the requested
                    // exclusions and the already-selected candidates
                    let excluded = closest_request
                        .excluded_peers
                        .iter()
                        .chain(candidates.iter())
                        .cloned()
                        .collect::<Vec<_>>();
                    // At least 2 additional peers are always requested, even if the
                    // connected candidates already satisfy broadcast_factor
                    let n = cmp::max(config.broadcast_factor.saturating_sub(candidates.len()), 2);
                    let additional = Self::select_closest_peers_for_propagation(
                        &peer_manager,
                        &closest_request.node_id,
                        n,
                        &excluded,
                        PeerFeatures::MESSAGE_PROPAGATION,
                    )
                    .await?;
                    candidates.extend(additional);
                }
                Ok(candidates)
            },
            Random(n, excluded) => {
                Ok(peer_manager
                    .random_peers(n, &excluded)
                    .await?
                    .into_iter()
                    .map(|p| p.node_id)
                    .collect())
            },
            Broadcast(exclude) => {
                let connections = connectivity
                    .select_connections(ConnectivitySelection::random_nodes(
                        config.broadcast_factor,
                        exclude.clone(),
                    ))
                    .await?;
                let candidates = connections
                    .iter()
                    .map(|c| c.peer_node_id())
                    .cloned()
                    .collect::<Vec<_>>();
                if candidates.is_empty() {
                    warn!(
                        target: LOG_TARGET,
                        "Broadcast requested but there are no node peer connections available"
                    );
                }
                debug!(
                    target: LOG_TARGET,
                    "{} candidate(s) selected for broadcast",
                    candidates.len()
                );
                Ok(candidates)
            },
            Propagate(destination, exclude) => {
                // Derive a target NodeId from the destination if possible
                let dest_node_id = destination
                    .node_id()
                    .map(Clone::clone)
                    .or_else(|| destination.public_key().and_then(|pk| NodeId::from_key(pk).ok()));
                let connections = match dest_node_id {
                    Some(node_id) => {
                        // If we are directly connected to the destination (and it
                        // is not excluded), send to it alone
                        let dest_connection = connectivity.get_connection(node_id.clone()).await?;
                        let dest_connection = dest_connection.filter(|c| !exclude.contains(c.peer_node_id()));
                        match dest_connection {
                            Some(conn) => {
                                vec![conn]
                            },
                            None => {
                                // Otherwise pick connected peers closest to the destination
                                let mut connections = connectivity
                                    .select_connections(ConnectivitySelection::closest_to(
                                        node_id.clone(),
                                        config.num_neighbouring_nodes,
                                        exclude.clone(),
                                    ))
                                    .await?;
                                if connections.len() >= config.propagation_factor {
                                    // Keep only peers at least as close to the
                                    // destination as this node is
                                    let dist_from_dest = node_identity.node_id().distance(&node_id);
                                    let before_len = connections.len();
                                    connections = connections
                                        .into_iter()
                                        .filter(|conn| conn.peer_node_id().distance(&node_id) <= dist_from_dest)
                                        .collect::<Vec<_>>();
                                    debug!(
                                        target: LOG_TARGET,
                                        "Filtered out {} node(s) that are further away than this node.",
                                        before_len - connections.len()
                                    );
                                }
                                connections.truncate(config.propagation_factor);
                                connections
                            },
                        }
                    },
                    None => {
                        debug!(
                            target: LOG_TARGET,
                            "No destination for propagation, sending to {} random peers", config.propagation_factor
                        );
                        connectivity
                            .select_connections(ConnectivitySelection::random_nodes(
                                config.propagation_factor,
                                exclude.clone(),
                            ))
                            .await?
                    },
                };
                if connections.is_empty() {
                    warn!(
                        target: LOG_TARGET,
                        "Propagation requested but there are no node peer connections available"
                    );
                }
                let candidates = connections
                    .iter()
                    .map(|c| c.peer_node_id())
                    .cloned()
                    .collect::<Vec<_>>();
                debug!(
                    target: LOG_TARGET,
                    "{} candidate(s) selected for propagation to {}",
                    candidates.len(),
                    destination
                );
                trace!(
                    target: LOG_TARGET,
                    "(ThisNode = {}) Candidates are {}",
                    node_identity.node_id().short_str(),
                    candidates.iter().map(|n| n.short_str()).collect::<Vec<_>>().join(", ")
                );
                Ok(candidates)
            },
        }
    }

    /// Queries the peer manager for up to `n` peers closest to `node_id` that
    /// have the required `features`, skipping banned, offline and explicitly
    /// excluded peers. Counters track why peers were rejected, for logging.
    async fn select_closest_peers_for_propagation(
        peer_manager: &PeerManager,
        node_id: &NodeId,
        n: usize,
        excluded_peers: &[NodeId],
        features: PeerFeatures,
    ) -> Result<Vec<NodeId>, DhtActorError>
    {
        let mut connect_ineligable_count = 0;
        let mut banned_count = 0;
        let mut excluded_count = 0;
        let mut filtered_out_node_count = 0;
        let query = PeerQuery::new()
            .select_where(|peer| {
                if peer.is_banned() {
                    banned_count += 1;
                    return false;
                }
                if !peer.features.contains(features) {
                    filtered_out_node_count += 1;
                    return false;
                }
                if peer.is_offline() {
                    connect_ineligable_count += 1;
                    return false;
                }
                let is_excluded = excluded_peers.contains(&peer.node_id);
                if is_excluded {
                    excluded_count += 1;
                    return false;
                }
                true
            })
            .sort_by(PeerQuerySortBy::DistanceFrom(&node_id))
            .limit(n);
        let peers = peer_manager.perform_query(query).await?;
        let total_excluded = banned_count + connect_ineligable_count + excluded_count + filtered_out_node_count;
        if total_excluded > 0 {
            debug!(
                target: LOG_TARGET,
                "👨👧👦 Closest Peer Selection: {num_peers} peer(s) selected, {total} peer(s) not selected, {banned} \
                 banned, {filtered_out} not communication node, {not_connectable} are not connectable, {excluded} \
                 explicitly excluded",
                num_peers = peers.len(),
                total = total_excluded,
                banned = banned_count,
                filtered_out = filtered_out_node_count,
                not_connectable = connect_ineligable_count,
                excluded = excluded_count
            );
        }
        Ok(peers.into_iter().map(|p| p.node_id).collect())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        broadcast_strategy::BroadcastClosestRequest,
        envelope::NodeDestination,
        test_utils::{build_peer_manager, make_client_identity, make_node_identity},
    };
    use chrono::{DateTime, Utc};
    use tari_comms::test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair};
    use tari_shutdown::Shutdown;
    use tari_test_utils::random;

    /// Creates an in-memory DB connection with migrations applied.
    async fn db_connection() -> DbConnection {
        let conn = DbConnection::connect_memory(random::string(8)).await.unwrap();
        conn.migrate().await.unwrap();
        conn
    }

    #[tokio_macros::test_basic]
    async fn send_join_request() {
        let node_identity = make_node_identity();
        let peer_manager = build_peer_manager();
        let (out_tx, mut out_rx) = mpsc::channel(1);
        let (connectivity_manager, mock) = create_connectivity_mock();
        mock.spawn();
        let (actor_tx, actor_rx) = mpsc::channel(1);
        let mut requester = DhtRequester::new(actor_tx);
        let outbound_requester = OutboundMessageRequester::new(out_tx);
        let shutdown = Shutdown::new();
        let actor = DhtActor::new(
            Default::default(),
            db_connection().await,
            node_identity,
            peer_manager,
            connectivity_manager,
            outbound_requester,
            actor_rx,
            shutdown.to_signal(),
        );
        actor.spawn();
        requester.send_join().await.unwrap();
        // The join should be submitted to the outbound pipeline as a DHT Join message
        let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap());
        assert_eq!(params.dht_message_type, DhtMessageType::Join);
    }

    #[tokio_macros::test_basic]
    async fn insert_message_signature() {
        let node_identity = make_node_identity();
        let peer_manager = build_peer_manager();
        let (connectivity_manager, mock) = create_connectivity_mock();
        mock.spawn();
        let (out_tx, _) = mpsc::channel(1);
        let (actor_tx, actor_rx) = mpsc::channel(1);
        let mut requester = DhtRequester::new(actor_tx);
        let outbound_requester = OutboundMessageRequester::new(out_tx);
        let shutdown = Shutdown::new();
        let actor = DhtActor::new(
            Default::default(),
            db_connection().await,
            node_identity,
            peer_manager,
            connectivity_manager,
            outbound_requester,
            actor_rx,
            shutdown.to_signal(),
        );
        actor.spawn();
        // First insert is not a duplicate, the repeat is, and a different hash is not
        let signature = vec![1u8, 2, 3];
        let is_dup = requester.insert_message_hash(signature.clone()).await.unwrap();
        assert_eq!(is_dup, false);
        let is_dup = requester.insert_message_hash(signature).await.unwrap();
        assert_eq!(is_dup, true);
        let is_dup = requester.insert_message_hash(Vec::new()).await.unwrap();
        assert_eq!(is_dup, false);
    }

    #[tokio_macros::test_basic]
    async fn select_peers() {
        let node_identity = make_node_identity();
        let peer_manager = build_peer_manager();
        let client_node_identity = make_client_identity();
        peer_manager.add_peer(client_node_identity.to_peer()).await.unwrap();
        let (connectivity_manager, mock) = create_connectivity_mock();
        let connectivity_manager_mock_state = mock.get_shared_state();
        mock.spawn();
        let (conn_in, _, conn_out, _) =
            create_peer_connection_mock_pair(1, client_node_identity.to_peer(), node_identity.to_peer()).await;
        connectivity_manager_mock_state.add_active_connection(conn_in).await;
        peer_manager.add_peer(make_node_identity().to_peer()).await.unwrap();
        let (out_tx, _) = mpsc::channel(1);
        let (actor_tx, actor_rx) = mpsc::channel(1);
        let mut requester = DhtRequester::new(actor_tx);
        let outbound_requester = OutboundMessageRequester::new(out_tx);
        let shutdown = Shutdown::new();
        let actor = DhtActor::new(
            Default::default(),
            db_connection().await,
            Arc::clone(&node_identity),
            peer_manager,
            connectivity_manager,
            outbound_requester,
            actor_rx,
            shutdown.to_signal(),
        );
        actor.spawn();
        // With no selectable connections in the mock, Broadcast yields nothing
        let peers = requester
            .select_peers(BroadcastStrategy::Broadcast(Vec::new()))
            .await
            .unwrap();
        assert_eq!(peers.len(), 0);
        connectivity_manager_mock_state
            .set_selected_connections(vec![conn_out.clone()])
            .await;
        let peers = requester
            .select_peers(BroadcastStrategy::Broadcast(Vec::new()))
            .await
            .unwrap();
        assert_eq!(peers.len(), 1);
        // Propagate with an unknown destination falls back to random connections
        let peers = requester
            .select_peers(BroadcastStrategy::Propagate(NodeDestination::Unknown, Vec::new()))
            .await
            .unwrap();
        assert_eq!(peers.len(), 1);
        let peers = requester
            .select_peers(BroadcastStrategy::Propagate(
                conn_out.peer_node_id().clone().into(),
                Vec::new(),
            ))
            .await
            .unwrap();
        assert_eq!(peers.len(), 1);
        // connected_only = false also pulls candidates from the peer manager
        let send_request = Box::new(BroadcastClosestRequest {
            node_id: node_identity.node_id().clone(),
            excluded_peers: vec![],
            connected_only: false,
        });
        let peers = requester
            .select_peers(BroadcastStrategy::Closest(send_request))
            .await
            .unwrap();
        assert_eq!(peers.len(), 2);
        let peers = requester
            .select_peers(BroadcastStrategy::DirectNodeId(Box::new(
                client_node_identity.node_id().clone(),
            )))
            .await
            .unwrap();
        assert_eq!(peers.len(), 1);
    }

    #[tokio_macros::test_basic]
    async fn get_and_set_metadata() {
        let node_identity = make_node_identity();
        let peer_manager = build_peer_manager();
        let (out_tx, _out_rx) = mpsc::channel(1);
        let (actor_tx, actor_rx) = mpsc::channel(1);
        let (connectivity_manager, mock) = create_connectivity_mock();
        mock.spawn();
        let mut requester = DhtRequester::new(actor_tx);
        let outbound_requester = OutboundMessageRequester::new(out_tx);
        let mut shutdown = Shutdown::new();
        let actor = DhtActor::new(
            Default::default(),
            db_connection().await,
            node_identity,
            peer_manager,
            connectivity_manager,
            outbound_requester,
            actor_rx,
            shutdown.to_signal(),
        );
        actor.spawn();
        // Unset key reads back as None
        assert!(requester
            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp)
            .await
            .unwrap()
            .is_none());
        // Set and read back round-trips the value
        let ts = Utc::now();
        requester
            .set_metadata(DhtMetadataKey::OfflineTimestamp, ts)
            .await
            .unwrap();
        let got_ts = requester
            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(got_ts, ts);
        // Setting again overwrites the previous value
        let ts = Utc::now().checked_add_signed(chrono::Duration::seconds(123)).unwrap();
        requester
            .set_metadata(DhtMetadataKey::OfflineTimestamp, ts)
            .await
            .unwrap();
        let got_ts = requester
            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(got_ts, ts);
        shutdown.trigger().unwrap();
    }
}