pub mod auth;
#[cfg(feature = "connection-manager")]
pub mod connection_manager;
pub mod dispatch;
mod dispatch_proto;
pub mod handlers;
pub mod peer;
#[cfg(feature = "network-peer-manager")]
pub mod peer_manager;
#[cfg(feature = "network-ref-map")]
pub(crate) mod ref_map;
pub(crate) mod reply;
pub mod sender;
use protobuf::Message;
use uuid::Uuid;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use crate::collections::BiHashMap;
use crate::mesh::{
AddError, Envelope, Mesh, RecvError as MeshRecvError, RecvTimeoutError as MeshRecvTimeoutError,
RemoveError, SendError as MeshSendError,
};
use crate::protos::network::{NetworkHeartbeat, NetworkMessage, NetworkMessageType};
use crate::transport::Connection;
/// A payload received from the network, tagged with the id of the peer that sent it.
#[derive(Debug)]
pub struct NetworkMessageWrapper {
    /// Id of the peer the payload came from.
    peer_id: String,
    /// Raw message bytes.
    payload: Vec<u8>,
}

impl NetworkMessageWrapper {
    /// Wraps `payload` together with the id of the peer it belongs to.
    pub fn new(peer_id: String, payload: Vec<u8>) -> Self {
        Self { peer_id, payload }
    }

    /// Returns the id of the peer associated with this message.
    pub fn peer_id(&self) -> &str {
        self.peer_id.as_str()
    }

    /// Returns the raw message bytes.
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }
}
/// Callback interface invoked with the id of a peer that has been disconnected.
pub trait DisconnectListener: Send {
    /// Called with the id of the peer that was disconnected.
    fn on_disconnect(&self, peer_id: &str);
}

/// Any sendable `Fn(&str)` closure or function can be used directly as a listener.
impl<F> DisconnectListener for F
where
    F: Fn(&str) + Send,
{
    fn on_disconnect(&self, peer_id: &str) {
        self(peer_id)
    }
}
/// Bookkeeping for the peers known to a `Network`.
struct PeerMap {
/// Two-way mapping between a peer id (key) and the mesh id of its connection (value).
peers: BiHashMap<String, String>,
/// Maps a previous peer id to the peer id that replaced it in `update`.
redirects: HashMap<String, String>,
/// Two-way mapping between a peer id (key) and the endpoint of its connection (value).
endpoints: BiHashMap<String, String>,
}
impl PeerMap {
    /// Builds a map with no peers, redirects, or endpoints.
    fn new() -> Self {
        Self {
            peers: BiHashMap::new(),
            redirects: HashMap::new(),
            endpoints: BiHashMap::new(),
        }
    }

    /// Returns an owned copy of every peer id currently stored in `peers`.
    fn peer_ids(&self) -> Vec<String> {
        self.peers.keys().map(|id| id.to_string()).collect()
    }

    /// Records `peer_id` together with the mesh id and endpoint of its connection.
    fn insert(&mut self, peer_id: String, mesh_id: String, endpoint: String) {
        let endpoint_key = peer_id.clone();
        self.peers.insert(peer_id, mesh_id);
        self.endpoints.insert(endpoint_key, endpoint);
    }

    /// Forgets `peer_id`: drops every redirect pointing at it, its endpoint entry, and its
    /// mesh id entry.
    ///
    /// Returns the mesh id that was stored for the peer, or `None` if `peers` had no entry.
    fn remove(&mut self, peer_id: &str) -> Option<String> {
        info!("Removing peer: {}", peer_id);

        self.redirects
            .retain(|_, target| target.as_str() != peer_id);
        self.endpoints.remove_by_key(peer_id);

        let removed = self.peers.remove_by_key(peer_id);
        removed.map(|(_, mesh_id)| mesh_id)
    }

    /// Re-keys the peer stored under `old_peer_id` to `new_peer_id`.
    ///
    /// Existing redirects that targeted `old_peer_id` are retargeted to `new_peer_id`, and a
    /// redirect from `old_peer_id` to `new_peer_id` is recorded so lookups by the old id keep
    /// resolving.
    ///
    /// Returns a `PeerUpdateError` carrying both ids when `old_peer_id` is not in `peers`.
    fn update(&mut self, old_peer_id: String, new_peer_id: String) -> Result<(), PeerUpdateError> {
        let mesh_id = match self.peers.remove_by_key(&old_peer_id) {
            Some((_, mesh_id)) => mesh_id,
            None => {
                return Err(PeerUpdateError {
                    old_peer_id,
                    new_peer_id,
                })
            }
        };
        self.peers.insert(new_peer_id.clone(), mesh_id);

        if let Some((_, endpoint)) = self.endpoints.remove_by_key(&old_peer_id) {
            self.endpoints.insert(new_peer_id.clone(), endpoint);
        }

        // Keep chains flat: anything that pointed at the old id now points at the new one.
        for target in self.redirects.values_mut() {
            if *target == old_peer_id {
                *target = new_peer_id.clone();
            }
        }
        self.redirects.insert(old_peer_id, new_peer_id);

        Ok(())
    }

    /// Looks up the mesh id for `peer_id`.
    ///
    /// A redirect for `peer_id` is tried first; if there is none, or its target has no mesh
    /// id, `peer_id` itself is looked up.
    fn get_mesh_id(&self, peer_id: &str) -> Option<&String> {
        if let Some(target) = self.redirects.get(peer_id) {
            if let Some(mesh_id) = self.peers.get_by_key(target) {
                return Some(mesh_id);
            }
        }
        self.peers.get_by_key(peer_id)
    }

    /// Looks up the peer id stored for `mesh_id`.
    fn get_peer_id(&self, mesh_id: &str) -> Option<&String> {
        self.peers.get_by_value(mesh_id)
    }

    /// Looks up the endpoint for `peer_id`, trying a redirect for `peer_id` first and falling
    /// back to `peer_id` itself.
    fn get_peer_endpoint(&self, peer_id: &str) -> Option<String> {
        if let Some(target) = self.redirects.get(peer_id) {
            if let Some(endpoint) = self.endpoints.get_by_key(target) {
                return Some(endpoint.clone());
            }
        }
        self.endpoints.get_by_key(peer_id).cloned()
    }

    /// Looks up the peer id stored for `endpoint`.
    fn get_peer_by_endpoint(&self, endpoint: &str) -> Option<String> {
        let found = self.endpoints.get_by_value(endpoint);
        found.cloned()
    }
}
/// Sends and receives messages to and from peers over a `Mesh`.
///
/// `peers` and `disconnect_listeners` are behind `Arc`, so clones of a `Network` share the
/// same peer map and listener list.
#[derive(Clone)]
pub struct Network {
/// Peer bookkeeping (peer id, mesh id, endpoint, redirects), shared between clones.
peers: Arc<RwLock<PeerMap>>,
/// The mesh that connections are added to and messages are sent/received through.
mesh: Mesh,
/// Listeners called with the peer id when a peer is disconnected, shared between clones.
disconnect_listeners: Arc<Mutex<Vec<Box<dyn DisconnectListener>>>>,
}
impl Network {
pub fn new(mesh: Mesh, heartbeat_interval: u64) -> Result<Self, NetworkStartUpError> {
let network = Network {
peers: Arc::new(RwLock::new(PeerMap::new())),
mesh,
disconnect_listeners: Arc::new(Mutex::new(vec![])),
};
if heartbeat_interval != 0 {
let heartbeat_network = network.clone();
let heartbeat = NetworkHeartbeat::new().write_to_bytes().map_err(|_| {
NetworkStartUpError("cannot create NetworkHeartbeat message".to_string())
})?;
let mut heartbeat_message = NetworkMessage::new();
heartbeat_message.set_message_type(NetworkMessageType::NETWORK_HEARTBEAT);
heartbeat_message.set_payload(heartbeat);
let heartbeat_bytes = heartbeat_message
.write_to_bytes()
.map_err(|_| NetworkStartUpError("cannot create NetworkMessage".to_string()))?;
let _ = thread::spawn(move || {
let interval = Duration::from_secs(heartbeat_interval);
thread::sleep(interval);
loop {
let peers = rwlock_read_unwrap!(heartbeat_network.peers).peer_ids();
for peer in peers {
heartbeat_network
.send(&peer, &heartbeat_bytes)
.unwrap_or_else(|err| {
error!("Unable to send heartbeat to {}: {:?}", peer, err)
});
}
thread::sleep(interval);
}
});
}
Ok(network)
}
pub fn peer_ids(&self) -> Vec<String> {
rwlock_read_unwrap!(self.peers).peer_ids()
}
pub fn get_peer_endpoint(&self, peer_id: &str) -> Option<String> {
rwlock_read_unwrap!(self.peers).get_peer_endpoint(peer_id)
}
pub fn get_peer_by_endpoint(&self, endpoint: &str) -> Option<String> {
rwlock_read_unwrap!(self.peers).get_peer_by_endpoint(endpoint)
}
pub fn add_disconnect_listener(&self, listener: Box<dyn DisconnectListener>) {
match self.disconnect_listeners.lock() {
Ok(mut listeners) => {
listeners.push(listener);
}
Err(_) => {
error!("Unable to add disconnect listener due to poisoned lock");
}
}
}
fn notify_disconnect_listeners(&self, peer_id: &str) {
match self.disconnect_listeners.lock() {
Ok(listeners) => {
listeners.iter().for_each(|listener| {
listener.on_disconnect(peer_id);
});
}
Err(_) => error!("Unable to notify disconnect listeners due to poisoned lock"),
}
}
pub fn add_connection(
&self,
connection: Box<dyn Connection>,
) -> Result<String, ConnectionError> {
let mut peers = rwlock_write_unwrap!(self.peers);
let endpoint = connection.remote_endpoint();
let mesh_id = format!("{}", Uuid::new_v4());
self.mesh.add(connection, mesh_id.clone())?;
let peer_id = format!("temp-{}", Uuid::new_v4());
peers.insert(peer_id.clone(), mesh_id, endpoint);
Ok(peer_id)
}
pub fn remove_connection(&self, peer_id: &str) -> Result<(), ConnectionError> {
if let Some(mesh_id) = rwlock_write_unwrap!(self.peers).remove(peer_id) {
let mut connection = self.mesh.remove(&mesh_id)?;
match connection.disconnect() {
Ok(_) => (),
Err(err) => warn!("Unable to disconnect from {}: {:?}", peer_id, err),
}
self.notify_disconnect_listeners(peer_id);
}
Ok(())
}
pub fn add_peer(
&self,
peer_id: String,
connection: Box<dyn Connection>,
) -> Result<(), ConnectionError> {
let mut peers = rwlock_write_unwrap!(self.peers);
let endpoint = connection.remote_endpoint();
let mesh_id = format!("{}", Uuid::new_v4());
self.mesh.add(connection, mesh_id.clone())?;
peers.insert(peer_id, mesh_id, endpoint);
Ok(())
}
pub fn update_peer_id(&self, old_id: String, new_id: String) -> Result<(), PeerUpdateError> {
rwlock_write_unwrap!(self.peers).update(old_id, new_id)
}
pub fn send(&self, peer_id: &str, msg: &[u8]) -> Result<(), SendError> {
let mesh_id = match rwlock_read_unwrap!(self.peers).get_mesh_id(peer_id) {
Some(mesh_id) => mesh_id.to_string(),
None => {
return Err(SendError::NoPeerError(peer_id.to_string()));
}
};
match self.mesh.send(Envelope::new(mesh_id, msg.to_vec())) {
Ok(()) => (),
Err(MeshSendError::Disconnected(err)) => {
rwlock_write_unwrap!(self.peers).remove(peer_id);
self.notify_disconnect_listeners(peer_id);
return Err(SendError::from(MeshSendError::Disconnected(err)));
}
Err(err) => return Err(SendError::from(err)),
}
Ok(())
}
pub fn recv(&self) -> Result<NetworkMessageWrapper, RecvError> {
let envelope = self.mesh.recv()?;
let peer_id = match rwlock_read_unwrap!(self.peers).get_peer_id(envelope.id()) {
Some(peer_id) => peer_id.to_string(),
None => {
return Err(RecvError::NoPeerError(format!(
"Recv Error: No Peer with mesh id {} found",
envelope.id()
)));
}
};
Ok(NetworkMessageWrapper::new(peer_id, envelope.take_payload()))
}
pub fn recv_timeout(
&self,
timeout: Duration,
) -> Result<NetworkMessageWrapper, RecvTimeoutError> {
let envelope = self.mesh.recv_timeout(timeout)?;
let peer_id = match rwlock_read_unwrap!(self.peers).get_peer_id(envelope.id()) {
Some(peer_id) => peer_id.to_string(),
None => {
return Err(RecvTimeoutError::NoPeerError(format!(
"Recv Error: No Peer with mesh id {} found",
envelope.id()
)));
}
};
Ok(NetworkMessageWrapper::new(peer_id, envelope.take_payload()))
}
}
/// Errors returned by `Network::recv`.
///
/// Both variants carry an already formatted, human-readable message.
#[derive(Debug)]
pub enum RecvError {
    /// The received envelope's mesh id has no entry in the peer map.
    NoPeerError(String),
    /// The underlying mesh receive failed.
    MeshError(String),
}

// Implemented so this error composes with `?` and `Box<dyn Error>` like the other error
// types in this module.
impl std::error::Error for RecvError {}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            // The payload is already a complete message; print it unchanged.
            RecvError::NoPeerError(msg) | RecvError::MeshError(msg) => f.write_str(msg),
        }
    }
}
/// Wraps a mesh receive error, keeping its `Debug` rendering as the message.
impl From<MeshRecvError> for RecvError {
    fn from(err: MeshRecvError) -> Self {
        let message = format!("Recv Error: {:?}", err);
        RecvError::MeshError(message)
    }
}
/// Errors returned by `Network::recv_timeout`.
#[derive(Debug)]
pub enum RecvTimeoutError {
    /// The received envelope's mesh id has no entry in the peer map; carries an already
    /// formatted message.
    NoPeerError(String),
    /// Mirrors the mesh's `Timeout` receive error.
    Timeout,
    /// Mirrors the mesh's `Disconnected` receive error.
    Disconnected,
    /// Mirrors the mesh's `PoisonedLock` receive error.
    PoisonedLock,
}

// Implemented so this error composes with `?` and `Box<dyn Error>` like the other error
// types in this module.
impl std::error::Error for RecvTimeoutError {}

impl std::fmt::Display for RecvTimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            // The payload is already a complete message; print it unchanged.
            RecvTimeoutError::NoPeerError(msg) => f.write_str(msg),
            RecvTimeoutError::Timeout => f.write_str("receive timed out"),
            RecvTimeoutError::Disconnected => f.write_str("receive failed: disconnected"),
            RecvTimeoutError::PoisonedLock => f.write_str("receive failed: poisoned lock"),
        }
    }
}
/// Maps each mesh receive-timeout error onto the variant of the same name.
impl From<MeshRecvTimeoutError> for RecvTimeoutError {
    fn from(err: MeshRecvTimeoutError) -> Self {
        match err {
            MeshRecvTimeoutError::PoisonedLock => RecvTimeoutError::PoisonedLock,
            MeshRecvTimeoutError::Disconnected => RecvTimeoutError::Disconnected,
            MeshRecvTimeoutError::Timeout => RecvTimeoutError::Timeout,
        }
    }
}
/// Errors returned by `Network::send`.
#[derive(Debug)]
pub enum SendError {
    /// No peer is known for the requested id; carries that peer id.
    NoPeerError(String),
    /// The mesh reported an error; carries its message.
    MeshError(String),
}

impl std::error::Error for SendError {}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            SendError::MeshError(reason) => write!(f, "received error from mesh: {}", reason),
            SendError::NoPeerError(peer_id) => {
                write!(f, "no peer with peer_id {} found", peer_id)
            }
        }
    }
}
/// Wraps a mesh send error, keeping its `Display` rendering as the message.
impl From<MeshSendError> for SendError {
    fn from(err: MeshSendError) -> Self {
        let reason = err.to_string();
        SendError::MeshError(reason)
    }
}
/// Errors returned when adding a connection to, or removing one from, the network.
#[derive(Debug)]
pub enum ConnectionError {
    /// The connection could not be added; carries the reason.
    AddError(String),
    /// The connection could not be removed; carries the reason.
    RemoveError(String),
}

impl std::error::Error for ConnectionError {}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            ConnectionError::RemoveError(reason) => {
                write!(f, "unable to remove connection: {}", reason)
            }
            ConnectionError::AddError(reason) => {
                write!(f, "unable to add connection: {}", reason)
            }
        }
    }
}
/// Error returned by `Network::new` when the network cannot be started; wraps the reason.
#[derive(Debug)]
pub struct NetworkStartUpError(String);

impl std::error::Error for NetworkStartUpError {}

impl std::fmt::Display for NetworkStartUpError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let NetworkStartUpError(reason) = self;
        write!(f, "network failed to startup: {}", reason)
    }
}
/// Wraps a mesh add error, keeping its `Debug` rendering as the message.
impl From<AddError> for ConnectionError {
    fn from(err: AddError) -> Self {
        let message = format!("Add Error: {:?}", err);
        ConnectionError::AddError(message)
    }
}
/// Wraps a mesh remove error, keeping its `Debug` rendering as the message.
impl From<RemoveError> for ConnectionError {
    fn from(err: RemoveError) -> Self {
        let message = format!("Remove Error: {:?}", err);
        ConnectionError::RemoveError(message)
    }
}
/// Error returned when a peer id could not be replaced with a new one.
#[derive(Debug)]
pub struct PeerUpdateError {
    /// The id the peer was expected to be stored under.
    pub old_peer_id: String,
    /// The id the peer was supposed to be renamed to.
    pub new_peer_id: String,
}

impl std::error::Error for PeerUpdateError {}

impl std::fmt::Display for PeerUpdateError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let PeerUpdateError {
            old_peer_id,
            new_peer_id,
        } = self;
        write!(f, "unable to update peer {} to {}", old_peer_id, new_peer_id)
    }
}
// Integration-style test of `Network` over a local TCP transport.
#[cfg(test)]
pub mod tests {
use super::*;
use crate::transport::socket::TcpTransport;
use crate::transport::Transport;
use std::fmt::Debug;
use std::thread;
// Returns the `Ok` value of `result`, or panics with the error's `Debug` output.
fn assert_ok<T, E: Debug>(result: Result<T, E>) -> T {
match result {
Ok(ok) => ok,
Err(err) => panic!("Expected Ok(...), got Err({:?})", err),
}
}
// Exercises two networks talking over TCP: adding a peer / connection, sending and
// receiving, renaming a temporary peer id, endpoint lookup, and heartbeats.
#[test]
fn test_network() {
// Both networks are created with a heartbeat interval of 2, which `Network::new`
// passes to `Duration::from_secs`.
let mesh_one = Mesh::new(5, 5);
let network_one = Network::new(mesh_one, 2).unwrap();
let mut transport = TcpTransport::default();
let mut listener = assert_ok(transport.listen("127.0.0.1:0"));
let endpoint = listener.endpoint();
// Connecting side: runs on its own thread with a second network.
thread::spawn(move || {
let mesh_two = Mesh::new(5, 5);
let network_two = Network::new(mesh_two, 2).unwrap();
let connection = assert_ok(transport.connect(&endpoint));
// The connection is registered under a generated temporary peer id.
assert_ok(network_two.add_connection(connection));
// The first message is expected to be the "345" sent by network_one below.
let message = assert_ok(network_two.recv());
assert_eq!(b"345", message.payload());
// Use the received payload as the peer's new id, replacing the temporary id.
let peer_id = String::from_utf8(message.payload().to_vec()).unwrap();
assert_ok(network_two.update_peer_id(message.peer_id().into(), peer_id.clone()));
assert_eq!(vec![peer_id.clone()], network_two.peer_ids());
assert_ok(network_two.send(&peer_id, b"hello_world"));
});
// Listening side: accept the connection and register it as peer "123".
let connection = assert_ok(listener.accept());
let remote_endpoint = connection.remote_endpoint();
assert_ok(network_one.add_peer("123".into(), connection));
// The peer can be found by the remote endpoint of its connection.
assert_eq!(
Some("123".into()),
network_one.get_peer_by_endpoint(&remote_endpoint)
);
assert_ok(network_one.send("123".into(), b"345"));
let message = assert_ok(network_one.recv());
assert_eq!("123", message.peer_id());
assert_eq!(b"hello_world", message.payload());
// Build the bytes a heartbeat message is expected to serialize to, the same way
// `Network::new` builds them.
let heartbeat = NetworkHeartbeat::new().write_to_bytes().unwrap();
let mut heartbeat_message = NetworkMessage::new();
heartbeat_message.set_message_type(NetworkMessageType::NETWORK_HEARTBEAT);
heartbeat_message.set_payload(heartbeat);
let heartbeat_bytes = heartbeat_message.write_to_bytes().unwrap();
// The next message from peer "123" is expected to be a heartbeat.
let message = assert_ok(network_one.recv());
assert_eq!("123", message.peer_id());
assert_eq!(heartbeat_bytes, message.payload());
}
}