use zookeeper::{ZooKeeper, WatchedEvent, WatchedEventType, ZkError, ZkResult};
use std::sync::{Arc, RwLock};
use serde::{Serialize};
use serde::de::DeserializeOwned;
pub use treediff::{value::Key, diff, tools::ChangeType};
use std::time::{Instant, Duration};
use std::thread;
use std::sync::RwLockReadGuard;
use serde_json::Value;
use crossbeam_channel::Receiver;
// Maximum age, in milliseconds, that the locally cached data may reach before
// `read` rejects it as stale (`d_read` only logs a warning).
const MAX_TIMING_DELTA: i64 = 30000;
// Sleep, in milliseconds, between two leadership checks while `update` waits
// for the write latch.
const LOCK_POLL_INTERVAL: u64 = 5;
// Accumulated polling time, in milliseconds, after which `update` gives up
// waiting for the write latch.
const LOCK_POLL_TIMEOUT: u64 = 1000;
/// Errors reported by the `ZkState` operations.
#[derive(Debug)]
pub enum ZkStructError {
/// `update` did not obtain the write latch within the polling budget.
LockAcquireTimeout,
/// `read` found the local copy older than `MAX_TIMING_DELTA` milliseconds.
StaleRead,
/// ZooKeeper rejected the write in `update` with `ZkError::BadVersion`.
StaleWrite,
/// Any other error reported by the ZooKeeper client.
ZkError(ZkError),
/// A local lock was poisoned.
Poisoned
}
impl From<ZkError> for ZkStructError {
fn from(error: ZkError) -> ZkStructError {
ZkStructError::ZkError(error)
}
}
/// Bookkeeping shared by every clone of a `ZkState` and by its helper threads.
struct InternalState {
// Base ZooKeeper path; the data lives at `{zk_path}/payload` and the write
// latch at `{zk_path}/write_lock`.
zk_path: String,
// Version of the payload node as of the last `state_change`.
epoch: i32,
// Time of the last successful refresh or periodic check; used to judge staleness.
timings: chrono::DateTime<chrono::Utc>,
// When true, `state_change` diffs the old and new values and publishes the changes.
emit_updates: bool,
// Receiving end of the change channel; cloned for handlers and subscribers.
chan_rx: crossbeam_channel::Receiver<Change<Key, serde_json::Value>>,
// Sending end of the change channel, used by `emit_updates`.
chan_tx: crossbeam_channel::Sender<Change<Key, serde_json::Value>>,
}
/// Handle to a value of type `T` that is mirrored in ZooKeeper.
///
/// Cloning a handle is cheap: every clone shares the same ZooKeeper client,
/// the same local copy of the value, the same bookkeeping, and the same
/// instance id.
pub struct ZkState<T: Serialize + DeserializeOwned + Send + Sync> {
    /// Shared ZooKeeper client.
    zk: Arc<ZooKeeper>,
    /// Identifier of this instance (a UUID string), used when taking the write latch.
    id: String,
    /// Local copy of the shared value.
    inner: Arc<RwLock<T>>,
    /// Path, epoch, freshness timestamp and change channel.
    state: Arc<RwLock<InternalState>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound even though only reference-counted handles are copied.
impl<T: Serialize + DeserializeOwned + Send + Sync> Clone for ZkState<T> {
    fn clone(&self) -> Self {
        Self {
            zk: Arc::clone(&self.zk),
            id: self.id.clone(),
            inner: Arc::clone(&self.inner),
            state: Arc::clone(&self.state),
        }
    }
}
impl<T: Serialize + DeserializeOwned + Send + Sync + 'static> ZkState<T> {
pub fn new(zk: Arc<ZooKeeper>, zk_path: String, initial_state: T) -> anyhow::Result<Self> {
let instance_id = uuid::Uuid::new_v4();
log::debug!("starting zkstate");
let (chan_tx, chan_rx) = crossbeam_channel::unbounded();
let r = Self {
id: instance_id.to_string(),
zk,
inner: Arc::new(RwLock::new(initial_state)),
state: Arc::new(RwLock::new(InternalState {
zk_path,
epoch: 0,
timings: chrono::Utc::now(),
emit_updates: true,
chan_rx,
chan_tx
}))
};
r.initialize()?;
Ok(r)
}
pub fn expect(zk: Arc<ZooKeeper>, zk_path: String) -> anyhow::Result<Self> {
let instance_id = uuid::Uuid::new_v4();
log::debug!("starting zkstate");
let (l_tx, l_rx) = crossbeam_channel::unbounded();
let raw_data = zk.get_data_w(format!("{}/payload", &zk_path).as_str(), move |_| {
let _ = l_tx.send(());
});
let mut data = vec![];
if let Ok(inner) = raw_data {
data = inner.0;
} else {
let _ = l_rx.recv();
data = zk.get_data(format!("{}/payload", &zk_path).as_str(), false)?.0;
}
let (chan_tx, chan_rx) = crossbeam_channel::unbounded();
let r = Self {
id: instance_id.to_string(),
zk,
inner: Arc::new(RwLock::new(serde_json::from_slice(data.as_slice()).unwrap())),
state: Arc::new(RwLock::new(InternalState {
zk_path,
epoch: 0,
timings: chrono::Utc::now(),
emit_updates: true,
chan_rx,
chan_tx
}))
};
r.initialize()?;
Ok(r)
}
fn initialize(&self) -> ZkResult<()> {
let path = format!("{}/payload", &self.state.read().unwrap().zk_path);
if self.zk.exists(path.as_str(), false).unwrap().is_none() {
log::debug!("{} does not exist, creating", &path);
self.zk.create(&self.state.read().unwrap().zk_path, vec![], zookeeper::Acl::open_unsafe().clone(), zookeeper::CreateMode::Persistent)?;
let data = self.inner.read().unwrap();
let inner = serde_json::to_vec(&*data).unwrap();
self.zk.create(path.as_str(), inner, zookeeper::Acl::open_unsafe().clone(), zookeeper::CreateMode::Persistent)?;
}
log::debug!("{} exists, continuing initialization", &path);
state_change(self.zk.clone(), self.inner.clone(), self.state.clone());
let zk = self.zk.clone();
let state = self.state.clone();
let inner = self.inner.clone();
thread::spawn(move || {
let zk = zk;
let inner = inner;
let state = state;
loop {
thread::sleep(Duration::from_secs(5));
let handle = state.read().unwrap();
let path = format!("{}/payload", handle.zk_path);
if let Some(meta) = zk.exists(path.as_str(), false).unwrap() {
if handle.epoch != meta.version {
log::warn!("the remote epoch has drifted. local: {}. remote: {}", handle.epoch, meta.version);
state_change(zk.clone(), inner.clone(), state.clone());
}
drop(handle);
state.write().unwrap().timings = chrono::Utc::now();
}
}
});
Ok(())
}
pub fn update_handler<M: Fn(Change<Key, serde_json::Value>) + Send + 'static>(&self, closure: M) -> Result<(), ZkStructError> {
let chan_handle = self.state.read().unwrap().chan_rx.clone();
thread::spawn(move || {
let rx = chan_handle;
loop {
let message = rx.recv().unwrap();
closure(message);
}
});
Ok(())
}
pub fn get_update_channel(&self) -> Receiver<Change<Key, Value>> {
self.state.read().unwrap().chan_rx.clone()
}
pub fn update<M: FnOnce(&mut T)>(&self, closure: M) -> Result<(), ZkStructError> {
let path = format!("{}/payload", &self.state.read().unwrap().zk_path);
let mut inner = self.inner.write().unwrap();
let state = self.state.write().unwrap();
let latch_path = format!("{}/write_lock", &state.zk_path);
let latch = zookeeper::recipes::leader::LeaderLatch::new(self.zk.clone(), self.id.clone(), latch_path);
latch.start()?;
let mut total_time = 0;
loop {
if latch.has_leadership() { break; }
thread::sleep(Duration::from_millis(LOCK_POLL_INTERVAL));
if total_time > LOCK_POLL_TIMEOUT {
return Err(ZkStructError::LockAcquireTimeout)
} else {
total_time += LOCK_POLL_INTERVAL;
}
}
let a = serde_json::to_value(&*inner).unwrap();
closure(&mut inner);
let b = serde_json::to_value(&*inner).unwrap();
emit_updates(&a, &b, &state);
let raw_data = serde_json::to_vec(&*inner).unwrap();
let update_op = self.zk.set_data(path.as_str(), raw_data, Some(state.epoch));
match update_op {
Ok(_) => {}
Err(err) => {
if err == ZkError::BadVersion {
return Err(ZkStructError::StaleWrite)
}
return Err(ZkStructError::ZkError(err))
}
}
drop(inner);
drop(state);
if let Err(inner) = latch.stop() {
return Err(ZkStructError::ZkError(inner));
}
Ok(())
}
pub fn read(&self) -> Result<RwLockReadGuard<'_, T>, ZkStructError> {
let state = self.state.read().unwrap();
let delta = (chrono::Utc::now() - state.timings).num_milliseconds();
if delta > MAX_TIMING_DELTA {
log::error!("attempted to read stale data. data is {}ms old, limit is {}ms", &delta, MAX_TIMING_DELTA);
return Err(ZkStructError::StaleRead)
}
log::debug!("reading internal data. data is {}ms old, limit is {}ms. epoch is {}", &delta, MAX_TIMING_DELTA, state.epoch);
match self.inner.read() {
Ok(inner) => Ok(inner),
Err(_) => Err(ZkStructError::Poisoned)
}
}
pub fn c_read(&self) { unimplemented!() }
pub fn d_read(&self) -> Result<RwLockReadGuard<'_, T>, ZkStructError> {
let state = self.state.read().unwrap();
let delta = (chrono::Utc::now() - state.timings).num_milliseconds();
if delta > MAX_TIMING_DELTA {
log::warn!("attempted to read stale data. data is {}ms old, limit is {}ms", &delta, MAX_TIMING_DELTA);
} else {
log::debug!("dirty reading internal data. data is {}ms old, limit is {}ms. epoch is {}", &delta, MAX_TIMING_DELTA, state.epoch);
}
match self.inner.read() {
Ok(inner) => Ok(inner),
Err(_) => Err(ZkStructError::Poisoned)
}
}
pub fn metadata(&self) -> (usize, i32) {
return (self.state.read().unwrap().chan_rx.len(), 0)
}
pub fn get_id(&self) -> &String {
&self.id
}
}
/// Watch callback for the payload node: a data change triggers a refresh via
/// `state_change`; every other event type is ignored.
fn handle_zk_watcher<T>(
    ev: WatchedEvent,
    zk: Arc<ZooKeeper>,
    inner: Arc<RwLock<T>>,
    state: Arc<RwLock<InternalState>>,
) where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    match ev.event_type {
        WatchedEventType::NodeDataChanged => state_change(zk, inner, state),
        _ => {}
    }
}
/// A single difference between two versions of the state, as published on the
/// update channel. `K` is the key type of the path and `V` the value type.
#[derive(PartialEq, Debug)]
pub enum Change<K, V> {
/// The value at the given key path was removed.
Removed(Vec<K>, V),
/// A value was added at the given key path.
Added(Vec<K>, V),
/// Nothing changed; `emit_updates` never sends this variant on the channel.
Unchanged(),
/// The value at the given key path changed; carries the two values in the
/// order reported by the diff (presumably previous, then current - confirm
/// against the treediff documentation).
Modified(Vec<K>, V, V),
}
fn state_change<T: Serialize + DeserializeOwned + Send + Sync + 'static>(zk: Arc<ZooKeeper>, inner: Arc<RwLock<T>>, state: Arc<RwLock<InternalState>>) {
let start = Instant::now();
let path = format!("{}/payload", &state.read().unwrap().zk_path);
let movable = (zk.clone(), inner.clone(), state.clone());
let raw_obj = zk.get_data_w(path.as_str(), move |ev| {
let movable = movable.clone();
handle_zk_watcher(ev, movable.0, movable.1, movable.2);
}).unwrap();
let mut a_handle = inner.write().unwrap();
let mut state = state.write().unwrap();
let b: serde_json::Value = serde_json::from_slice(&*raw_obj.0).unwrap();
if state.emit_updates {
let a: serde_json::Value = serde_json::to_value(&*a_handle).unwrap();
emit_updates(&a, &b, &state);
}
*a_handle = serde_json::from_value(b).unwrap();
state.epoch = raw_obj.1.version;
state.timings = chrono::Utc::now();
drop(a_handle);
drop(state);
log::debug!("took {}ms to handle state change", start.elapsed().as_millis());
}
/// Diffs `a` against `b` and publishes every difference on the change channel.
///
/// Unchanged entries are counted for the debug log but never sent. Send
/// failures are deliberately ignored (best effort).
fn emit_updates(a: &serde_json::Value, b: &serde_json::Value, state: &InternalState) {
    let mut delta = treediff::tools::Recorder::default();
    diff(a, b, &mut delta);

    let (mut added, mut removed, mut modified, mut noop) = (0, 0, 0, 0);
    for change in delta.calls {
        // The key paths are owned by the recorded change and are moved into
        // the outgoing message; only the borrowed values need cloning.
        let op = match change {
            ChangeType::Added(k, v) => {
                added += 1;
                Change::Added(k, v.clone())
            }
            ChangeType::Removed(k, v) => {
                removed += 1;
                Change::Removed(k, v.clone())
            }
            ChangeType::Modified(k, left, right) => {
                modified += 1;
                Change::Modified(k, left.clone(), right.clone())
            }
            ChangeType::Unchanged(_, _) => {
                noop += 1;
                continue;
            }
        };
        let _ = state.chan_tx.send(op);
    }
    log::debug!("{} added, {} removed, {} modified, {} noop", added, removed, modified, noop);
}