use zookeeper::{ZooKeeper, WatchedEvent, WatchedEventType, ZkError, ZkResult};
use std::sync::{Arc, RwLock, LockResult};
use serde::{Serialize};
use serde::de::DeserializeOwned;
use treediff::{value::Key, diff, tools::ChangeType};
use std::time::{Instant, Duration};
use std::thread;
use anyhow::Context;
use std::sync::RwLockReadGuard;
/// Delay between successive leadership checks while waiting for the ZooKeeper
/// write lock, in milliseconds (fed to `Duration::from_millis`).
const LOCK_POLL_INTERVAL: u64 = 5;
/// How long `ZkState::update` keeps polling for the write lock before giving up
/// with `ZkStructError::LockAcquireTimeout`, in milliseconds.
const LOCK_POLL_TIMEOUT: u64 = 1000;
pub enum ZkStructError {
LockAcquireTimeout,
StaleWrite,
ZkError(ZkError)
}
/// Bookkeeping shared between a `ZkState` handle and its ZooKeeper watch callback.
struct InternalState {
    /// ZooKeeper node under which the `payload` node and the `write_lock` latch live.
    zk_path: String,
    /// Last known version of the payload node; passed as the expected version to
    /// `set_data` so concurrent writes are detected.
    epoch: i32,
    /// When true, `state_change` pushes JSON diffs onto the change channel.
    emit_updates: bool,
    /// Producer side of the change-event channel.
    chan_tx: crossbeam_channel::Sender<Change<Key, serde_json::Value>>,
    /// Consumer side of the change-event channel; in this file only its length is read.
    chan_rx: crossbeam_channel::Receiver<Change<Key, serde_json::Value>>,
    /// Three timestamps set at construction.
    /// NOTE(review): not read anywhere in this file — purpose unclear.
    timings: (
        chrono::DateTime<chrono::Utc>,
        chrono::DateTime<chrono::Utc>,
        chrono::DateTime<chrono::Utc>,
    ),
}
/// A value of type `T` kept in sync with a JSON payload stored at `{zk_path}/payload`
/// in ZooKeeper.
///
/// Cloning is cheap: clones share the same ZooKeeper client, local value, internal
/// state and instance id.
pub struct ZkState<T: Serialize + DeserializeOwned + Send + Sync> {
    zk: Arc<ZooKeeper>,
    id: String,
    inner: Arc<RwLock<T>>,
    state: Arc<RwLock<InternalState>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone` even though only
// the `Arc` handles (and the id string) are cloned, making handles to non-`Clone`
// payloads uncloneable.
impl<T: Serialize + DeserializeOwned + Send + Sync> Clone for ZkState<T> {
    fn clone(&self) -> Self {
        Self {
            zk: Arc::clone(&self.zk),
            id: self.id.clone(),
            inner: Arc::clone(&self.inner),
            state: Arc::clone(&self.state),
        }
    }
}
impl<T: Serialize + DeserializeOwned + Send + Sync + 'static> ZkState<T> {
pub fn new(zk: Arc<ZooKeeper>, zk_path: String, initial_state: T) -> anyhow::Result<Self> {
let instance_id = uuid::Uuid::new_v4();
let (chan_tx, chan_rx) = crossbeam_channel::unbounded();
let r = Self {
id: instance_id.to_string(),
zk,
inner: Arc::new(RwLock::new(initial_state)),
state: Arc::new(RwLock::new(InternalState {
zk_path,
epoch: 0,
timings: (chrono::Utc::now(), chrono::Utc::now(), chrono::Utc::now()),
emit_updates: true,
chan_rx,
chan_tx
}))
};
r.initialize()?;
Ok(r)
}
fn initialize(&self) -> ZkResult<()> {
let path = format!("{}/payload", &self.state.read().unwrap().zk_path);
if self.zk.exists(path.as_str(), false).unwrap().is_none() {
log::debug!("{} does not exist, creating", &path);
self.zk.create(&self.state.read().unwrap().zk_path, vec![], zookeeper::Acl::open_unsafe().clone(), zookeeper::CreateMode::Persistent)?;
let data = self.inner.read().unwrap();
let inner = serde_json::to_vec(&*data).unwrap();
self.zk.create(path.as_str(), inner, zookeeper::Acl::open_unsafe().clone(), zookeeper::CreateMode::Persistent)?;
}
log::debug!("{} exists, continuing initialization", &path);
state_change(self.zk.clone(), self.inner.clone(), self.state.clone());
thread::spawn(|| {
});
Ok(())
}
pub fn update<M: FnOnce(&mut T) -> ()>(self, closure: M) -> Result<(), ZkStructError> {
let path = format!("{}/payload", &self.state.read().unwrap().zk_path);
let mut inner = self.inner.write().unwrap();
let mut state = self.state.write().unwrap();
let latch_path = format!("{}/write_lock", &state.zk_path);
let latch = zookeeper::recipes::leader::LeaderLatch::new(self.zk.clone(), self.id.clone(), latch_path);
latch.start();
let mut total_time = 0;
loop {
if latch.has_leadership() { break; }
thread::sleep(Duration::from_millis(LOCK_POLL_INTERVAL));
if total_time > LOCK_POLL_TIMEOUT {
return Err(ZkStructError::LockAcquireTimeout)
} else {
total_time += LOCK_POLL_INTERVAL;
}
}
let a = serde_json::to_value(&*inner).unwrap();
closure(&mut inner);
let b = serde_json::to_value(&*inner).unwrap();
emit_updates(&a, &b, &state);
let raw_data = serde_json::to_vec(&*inner).unwrap();
let update_op = self.zk.set_data(path.as_str(), raw_data, Some(state.epoch));
match update_op {
Ok(_) => {}
Err(err) => {
if err == ZkError::BadVersion {
return Err(ZkStructError::StaleWrite)
}
return Err(ZkStructError::ZkError(err))
}
}
drop(inner);
drop(state);
Ok(())
}
pub fn read(&self) -> LockResult<RwLockReadGuard<'_, T>> {
self.inner.read()
}
pub fn metadata(&self) -> (usize, i32) {
return (self.state.read().unwrap().chan_rx.len(), 0)
}
}
/// Data-watch callback installed on the payload node by `state_change`.
///
/// On `NodeDataChanged` it calls `state_change`, which re-reads the payload with
/// `get_data_w` and thereby installs a fresh watch. All other event types are ignored.
/// NOTE(review): an ignored event (e.g. node deletion) is not followed by a re-read,
/// so no new watch is installed in that case.
fn handle_zk_watcher<T: Serialize + DeserializeOwned + Send + Sync + 'static>(
    ev: WatchedEvent,
    zk: Arc<ZooKeeper>,
    inner: Arc<RwLock<T>>,
    state: Arc<RwLock<InternalState>>,
) {
    if let WatchedEventType::NodeDataChanged = ev.event_type {
        state_change(zk, inner, state);
    }
}
/// One structural difference between two JSON snapshots of the state, mirroring
/// `treediff::tools::ChangeType` with owned values.
///
/// The `Vec<K>` in each variant is the key path from the document root.
#[derive(Debug, Clone, PartialEq)]
pub enum Change<K, V> {
    /// Key path and value of an entry reported as removed.
    Removed(Vec<K>, V),
    /// Key path and value of an entry reported as added.
    Added(Vec<K>, V),
    /// Placeholder for an unchanged entry; not sent on the change channel.
    Unchanged(),
    /// Key path plus the old and new values of a modified entry.
    Modified(Vec<K>, V, V),
}
fn state_change<'a, T: Serialize + DeserializeOwned + Send + Sync + 'static>(zk: Arc<ZooKeeper>, inner: Arc<RwLock<T>>, state: Arc<RwLock<InternalState>>) {
let start = Instant::now();
let path = format!("{}/payload", &state.read().unwrap().zk_path);
let movable = (zk.clone(), inner.clone(), state.clone());
let raw_obj = zk.get_data_w(path.as_str(), move |ev| {
let movable = movable.clone();
handle_zk_watcher(ev, movable.0, movable.1, movable.2);
}).unwrap();
let mut a_handle = inner.write().unwrap();
let mut state = state.write().unwrap();
let b: serde_json::Value = serde_json::from_slice(&*raw_obj.0).unwrap();
if state.emit_updates {
let a: serde_json::Value = serde_json::to_value(&*a_handle).unwrap();
emit_updates(&a, &b, &state);
}
*a_handle = serde_json::from_value(b).unwrap();
state.epoch = raw_obj.1.version;
drop(a_handle);
drop(state);
log::debug!("took {}ms to handle state change", start.elapsed().as_millis());
}
/// Diffs snapshot `a` (previous) against `b` (new) and sends every difference other than
/// "unchanged" on `state.chan_tx`, then logs per-kind counts at debug level.
///
/// Send errors are ignored. The receiver is stored in the same `InternalState`, so the
/// channel cannot be disconnected while `state` is borrowed here.
fn emit_updates(a: &serde_json::Value, b: &serde_json::Value, state: &InternalState) {
    let mut recorder = treediff::tools::Recorder::default();
    diff(a, b, &mut recorder);

    let (mut added, mut removed, mut modified, mut unchanged) = (0usize, 0usize, 0usize, 0usize);
    for call in recorder.calls {
        // The recorder owns each key path, so it is moved into the event, not cloned.
        let change = match call {
            ChangeType::Added(path, value) => {
                added += 1;
                Change::Added(path, value.clone())
            }
            ChangeType::Removed(path, value) => {
                removed += 1;
                Change::Removed(path, value.clone())
            }
            ChangeType::Modified(path, old, new) => {
                modified += 1;
                Change::Modified(path, old.clone(), new.clone())
            }
            ChangeType::Unchanged(..) => {
                unchanged += 1;
                continue;
            }
        };
        let _ = state.chan_tx.send(change);
    }
    log::debug!(
        "{} added, {} removed, {} modified, {} noop",
        added,
        removed,
        modified,
        unchanged
    );
}