use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::Mutex;
use tracing::*;
use crate::listeners::ListenerSet;
use crate::paths::make_path;
use crate::{
Stat, Subscription, WatchedEvent, WatchedEventType, ZkError, ZkResult, ZkState, ZooKeeper,
ZooKeeperExt,
};
/// A single child's raw data bytes together with its `Stat`, wrapped in an `Arc` so the
/// cache, snapshots and events can share one copy (cloning only bumps the refcount).
pub type ChildData = Arc<(Vec<u8>, Stat)>;
/// The cache contents: child data keyed by the child's full path
/// (built with `make_path(parent, child)`).
pub type Data = HashMap<String, ChildData>;
/// Notification delivered to listeners registered through `PathChildrenCache::add_listener`.
///
/// Every `String` carried by a child variant is the child's full path, built with
/// `make_path(parent, child)`.
#[derive(Clone, Debug)]
pub enum PathChildrenCacheEvent {
    /// Sent after the initial load attempt, carrying a snapshot of the cache at that point.
    Initialized(Data),
    /// NOTE(review): not emitted anywhere in this file; presumably reserved for
    /// connection-state notifications. Confirm before relying on it.
    ConnectionSuspended,
    /// NOTE(review): not emitted anywhere in this file (see `ConnectionSuspended`).
    ConnectionLost,
    /// NOTE(review): not emitted anywhere in this file (see `ConnectionSuspended`).
    ConnectionReconnected,
    /// A child was removed from the cache.
    ChildRemoved(String),
    /// A child's data and `Stat` were loaded into the cache. This is also emitted for
    /// already-cached children when a forced (`ForceGetDataAndStat`) refresh re-reads them.
    ChildAdded(String, ChildData),
    /// A cached child's data was re-read after a data-change notification.
    ChildUpdated(String, ChildData),
}
/// Controls how a refresh treats children that are already present in the cache.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RefreshMode {
    /// Read data only for children that are not cached yet.
    Standard,
    /// Re-read data and `Stat` for every listed child, cached or not.
    ForceGetDataAndStat,
}
/// Work items consumed one at a time by the cache's background worker task.
// `Shutdown` is never constructed in this file; without this allow, rustc would warn
// that the variant is never constructed.
#[allow(dead_code)]
#[derive(Debug)]
enum Operation {
    /// Run the initial forced load, then notify listeners with `Initialized`.
    Initialize,
    /// Stop the worker loop.
    Shutdown,
    /// Re-list the children using the given mode.
    Refresh(RefreshMode),
    /// Hand the event to the registered listeners.
    Event(PathChildrenCacheEvent),
    /// Re-read the data of the child at this full path.
    GetData(String),
    /// React to a ZooKeeper connection-state change.
    ZkStateEvent(ZkState),
}
/// Keeps an in-memory copy of the children of one znode (their data and `Stat`) and
/// notifies registered listeners about changes.
///
/// Built with `PathChildrenCache::new`. Children are not read and no watches are set
/// until `start` is called.
pub struct PathChildrenCache {
/// The parent znode whose children are cached.
path: Arc<String>,
/// ZooKeeper client used for every read and watch.
zk: Arc<ZooKeeper>,
/// Cached child data keyed by full child path; shared with the worker task and watchers.
data: Arc<Mutex<Data>>,
/// Sender half of the worker's operation queue; `None` until `start` is called.
channel: Option<UnboundedSender<Operation>>,
/// Handle returned when `start` registers the ZooKeeper state listener; `None` before that.
listener_subscription: Option<Subscription>,
/// Listeners notified of `PathChildrenCacheEvent`s.
event_listeners: ListenerSet<PathChildrenCacheEvent>,
}
impl PathChildrenCache {
pub async fn new(zk: Arc<ZooKeeper>, path: &str) -> ZkResult<PathChildrenCache> {
let data = Arc::new(Mutex::new(HashMap::new()));
zk.ensure_path(path).await?;
Ok(PathChildrenCache {
path: Arc::new(path.to_owned()),
zk,
data,
channel: None,
listener_subscription: None,
event_listeners: ListenerSet::new(),
})
}
async fn get_children(
zk: Arc<ZooKeeper>,
path: &str,
data: Arc<Mutex<Data>>,
ops_chan: UnboundedSender<Operation>,
mode: RefreshMode,
) -> ZkResult<()> {
let ops_chan1 = ops_chan.clone();
let watcher = move |event: WatchedEvent| {
match event.event_type {
WatchedEventType::NodeChildrenChanged => {
let _path = event.path.as_ref().expect("Path absent");
if let Err(err) = ops_chan1.send(Operation::Refresh(RefreshMode::Standard)) {
warn!("error sending Refresh operation to ops channel: {:?}", err);
}
}
_ => error!("Unexpected: {:?}", event),
};
};
let children = zk.get_children_w(path, watcher).await?;
let mut data_locked = data.lock().await;
for child in &children {
let child_path = make_path(path, child);
if mode == RefreshMode::ForceGetDataAndStat || !data_locked.contains_key(&child_path) {
let child_data = Arc::new(
Self::get_data(zk.clone(), &child_path, data.clone(), ops_chan.clone()).await?,
);
data_locked.insert(child_path.clone(), child_data.clone());
ops_chan
.send(Operation::Event(PathChildrenCacheEvent::ChildAdded(
child_path, child_data,
)))
.map_err(|err| {
info!("error sending ChildAdded event: {:?}", err);
ZkError::APIError
})?;
}
}
trace!("New data: {:?}", *data_locked);
Ok(())
}
async fn get_data(
zk: Arc<ZooKeeper>,
path: &str,
data: Arc<Mutex<Data>>,
ops_chan: UnboundedSender<Operation>,
) -> ZkResult<(Vec<u8>, Stat)> {
let path1 = path.to_owned();
let data_watcher = move |event: WatchedEvent| {
let data = data.clone();
let ops_chan = ops_chan.clone();
let path1 = path1.clone();
tokio::spawn(async move {
let mut data_locked = data.lock().await;
match event.event_type {
WatchedEventType::NodeDeleted => {
data_locked.remove(&path1);
if let Err(err) = ops_chan.send(Operation::Event(
PathChildrenCacheEvent::ChildRemoved(path1.clone()),
)) {
warn!("error sending ChildRemoved event: {:?}", err);
}
}
WatchedEventType::NodeDataChanged => {
if let Err(err) = ops_chan.send(Operation::GetData(path1.clone())) {
warn!("error sending GetData to op channel: {:?}", err);
}
}
_ => error!("Unexpected: {:?}", event),
};
trace!("New data: {:?}", *data_locked);
});
};
zk.get_data_w(path, data_watcher).await
}
async fn update_data(
zk: Arc<ZooKeeper>,
path: &str,
data: Arc<Mutex<Data>>,
ops_chan_tx: UnboundedSender<Operation>,
) -> ZkResult<()> {
let mut data_locked = data.lock().await;
let path = path.to_owned();
let result = Self::get_data(zk.clone(), &path, data.clone(), ops_chan_tx.clone()).await;
match result {
Ok(child_data) => {
trace!("got data {:?}", child_data);
let child_data = Arc::new(child_data);
data_locked.insert(path.clone(), child_data.clone());
ops_chan_tx
.send(Operation::Event(PathChildrenCacheEvent::ChildUpdated(
path, child_data,
)))
.map_err(|err| {
warn!("error sending ChildUpdated event: {:?}", err);
ZkError::APIError
})
}
Err(err) => {
warn!("error getting child data: {:?}", err);
Err(ZkError::APIError)
}
}
}
pub async fn get_current_data(&self) -> Data {
self.data.lock().await.clone()
}
pub async fn clear(&self) {
self.data.lock().await.clear()
}
fn handle_state_change(state: ZkState, ops_chan_tx: UnboundedSender<Operation>) -> bool {
let mut done = false;
debug!("zk state change {:?}", state);
if let ZkState::Connected = state {
if let Err(err) = ops_chan_tx.send(Operation::Refresh(RefreshMode::ForceGetDataAndStat))
{
warn!("error sending Refresh to op channel: {:?}", err);
done = true;
}
}
done
}
async fn handle_operation(
op: Operation,
zk: Arc<ZooKeeper>,
path: Arc<String>,
data: Arc<Mutex<Data>>,
event_listeners: ListenerSet<PathChildrenCacheEvent>,
ops_chan_tx: UnboundedSender<Operation>,
) -> bool {
let mut done = false;
match op {
Operation::Initialize => {
debug!("initialising...");
let result = Self::get_children(
zk.clone(),
&path,
data.clone(),
ops_chan_tx.clone(),
RefreshMode::ForceGetDataAndStat,
)
.await;
debug!("got children {:?}", result);
event_listeners.notify(&PathChildrenCacheEvent::Initialized(
data.lock().await.clone(),
));
}
Operation::Shutdown => {
debug!("shutting down worker thread");
done = true;
}
Operation::Refresh(mode) => {
debug!("getting children");
let result =
Self::get_children(zk.clone(), &path, data.clone(), ops_chan_tx.clone(), mode)
.await;
debug!("got children {:?}", result);
}
Operation::GetData(path) => {
debug!("getting data");
let result =
Self::update_data(zk.clone(), &path, data.clone(), ops_chan_tx.clone()).await;
if let Err(err) = result {
warn!("error getting child data: {:?}", err);
}
}
Operation::Event(event) => {
debug!("received event {:?}", event);
event_listeners.notify(&event);
}
Operation::ZkStateEvent(state) => {
done = Self::handle_state_change(state, ops_chan_tx.clone());
}
}
done
}
pub fn start(&mut self) -> ZkResult<()> {
let (ops_chan_tx, mut ops_chan_rx) = unbounded_channel();
let ops_chan_rx_zk_events = ops_chan_tx.clone();
let sub = self.zk.add_listener(move |s| {
ops_chan_rx_zk_events
.send(Operation::ZkStateEvent(s))
.unwrap()
});
self.listener_subscription = Some(sub);
let zk = self.zk.clone();
let path = self.path.clone();
let data = self.data.clone();
let event_listeners = self.event_listeners.clone();
self.channel = Some(ops_chan_tx.clone());
tokio::spawn(async move {
let mut done = false;
while !done {
match ops_chan_rx.recv().await {
Some(operation) => {
done = Self::handle_operation(
operation,
zk.clone(),
path.clone(),
data.clone(),
event_listeners.clone(),
ops_chan_tx.clone(),
)
.await;
}
None => {
info!("error receiving from operations channel. shutting down");
done = true;
}
}
}
});
self.offer_operation(Operation::Initialize)
}
pub fn add_listener<Listener: Fn(PathChildrenCacheEvent) + Send + 'static>(
&self,
subscriber: Listener,
) -> Subscription {
self.event_listeners.subscribe(subscriber)
}
pub fn remove_listener(&self, sub: Subscription) {
self.event_listeners.unsubscribe(sub)
}
fn offer_operation(&self, op: Operation) -> ZkResult<()> {
match self.channel {
Some(ref chan) => chan.send(op).map_err(|err| {
warn!("error submitting op to channel: {:?}", err);
ZkError::APIError
}),
None => Err(ZkError::APIError),
}
}
}