1use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
5use tokio::sync::Mutex;
6use tracing::*;
7
8use crate::listeners::ListenerSet;
9use crate::paths::make_path;
10use crate::{
11 Stat, Subscription, WatchedEvent, WatchedEventType, ZkError, ZkResult, ZkState, ZooKeeper,
12 ZooKeeperExt,
13};
14
/// Cached state of one child node: its raw data bytes together with its `Stat`,
/// wrapped in an `Arc` so cache snapshots and emitted events can share it without
/// copying the bytes.
pub type ChildData = Arc<(Vec<u8>, Stat)>;

/// Contents of the cache, keyed by the child's full path (as built by `make_path`
/// from the watched path and the child name).
pub type Data = HashMap<String, ChildData>;
20
/// Notifications delivered to listeners registered with `PathChildrenCache::add_listener`.
#[derive(Debug, Clone)]
pub enum PathChildrenCacheEvent {
    /// The initial load finished; carries a snapshot of the whole cache at that point.
    Initialized(Data),
    /// NOTE(review): connection-state variant; not emitted by the code visible in this
    /// file — confirm whether anything else produces it.
    ConnectionSuspended,
    /// NOTE(review): connection-state variant; not emitted by the code visible in this
    /// file — confirm whether anything else produces it.
    ConnectionLost,
    /// NOTE(review): connection-state variant; not emitted by the code visible in this
    /// file — confirm whether anything else produces it.
    ConnectionReconnected,
    /// A cached child node was deleted; carries the child's full path.
    ChildRemoved(String),
    /// A child was loaded into the cache; carries its full path and data.
    ChildAdded(String, ChildData),
    /// A cached child's data was re-read after a change; carries its full path and new data.
    ChildUpdated(String, ChildData),
}
31
/// How a children refresh treats children that are already in the cache.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum RefreshMode {
    /// Only read data for children that are not cached yet.
    Standard,
    /// Re-read data and `Stat` for every child, cached or not.
    ForceGetDataAndStat,
}
37
// Presumably silences the warning for `Shutdown`, which is matched by the worker but
// never constructed in the code visible in this file — confirm before removing.
#[allow(dead_code)]
/// Work items queued to the cache's background worker task.
#[derive(Debug)]
enum Operation {
    /// Perform the initial forced load, then emit `PathChildrenCacheEvent::Initialized`.
    Initialize,
    /// Stop the worker loop.
    Shutdown,
    /// Re-list the children (re-arming the children watch) in the given mode.
    Refresh(RefreshMode),
    /// Forward an event to the registered listeners.
    Event(PathChildrenCacheEvent),
    /// Re-read the data of the child at this full path.
    GetData(String),
    /// A ZooKeeper connection-state change reported by the client listener.
    ZkStateEvent(ZkState),
}
48
/// Keeps an in-memory copy of the children of one ZooKeeper node (and their data),
/// maintained by watches and a background worker task, and notifies listeners of changes.
pub struct PathChildrenCache {
    /// Path of the parent node whose children are cached.
    path: Arc<String>,
    /// ZooKeeper client used for all reads and watches.
    zk: Arc<ZooKeeper>,
    /// The cached children, shared with the worker task and watcher callbacks.
    data: Arc<Mutex<Data>>,
    /// Sender side of the worker's operation queue; `None` until `start` is called.
    channel: Option<UnboundedSender<Operation>>,
    /// Subscription for the ZooKeeper connection-state listener installed by `start`.
    listener_subscription: Option<Subscription>,
    /// Listeners that receive `PathChildrenCacheEvent`s.
    event_listeners: ListenerSet<PathChildrenCacheEvent>,
}
69
70impl PathChildrenCache {
71 pub async fn new(zk: Arc<ZooKeeper>, path: &str) -> ZkResult<PathChildrenCache> {
77 let data = Arc::new(Mutex::new(HashMap::new()));
78
79 zk.ensure_path(path).await?;
80
81 Ok(PathChildrenCache {
82 path: Arc::new(path.to_owned()),
83 zk,
84 data,
85 channel: None,
86 listener_subscription: None,
87 event_listeners: ListenerSet::new(),
88 })
89 }
90
91 async fn get_children(
92 zk: Arc<ZooKeeper>,
93 path: &str,
94 data: Arc<Mutex<Data>>,
95 ops_chan: UnboundedSender<Operation>,
96 mode: RefreshMode,
97 ) -> ZkResult<()> {
98 let ops_chan1 = ops_chan.clone();
99
100 let watcher = move |event: WatchedEvent| {
101 match event.event_type {
102 WatchedEventType::NodeChildrenChanged => {
103 let _path = event.path.as_ref().expect("Path absent");
104
105 if let Err(err) = ops_chan1.send(Operation::Refresh(RefreshMode::Standard)) {
107 warn!("error sending Refresh operation to ops channel: {:?}", err);
108 }
109 }
110 _ => error!("Unexpected: {:?}", event),
111 };
112 };
113
114 let children = zk.get_children_w(path, watcher).await?;
115
116 let mut data_locked = data.lock().await;
117
118 for child in &children {
119 let child_path = make_path(path, child);
120
121 if mode == RefreshMode::ForceGetDataAndStat || !data_locked.contains_key(&child_path) {
122 let child_data = Arc::new(
123 Self::get_data(zk.clone(), &child_path, data.clone(), ops_chan.clone()).await?,
124 );
125
126 data_locked.insert(child_path.clone(), child_data.clone());
127
128 ops_chan
129 .send(Operation::Event(PathChildrenCacheEvent::ChildAdded(
130 child_path, child_data,
131 )))
132 .map_err(|err| {
133 info!("error sending ChildAdded event: {:?}", err);
134 ZkError::APIError
135 })?;
136 }
137 }
138
139 trace!("New data: {:?}", *data_locked);
140
141 Ok(())
142 }
143
144 async fn get_data(
145 zk: Arc<ZooKeeper>,
146 path: &str,
147 data: Arc<Mutex<Data>>,
148 ops_chan: UnboundedSender<Operation>,
149 ) -> ZkResult<(Vec<u8>, Stat)> {
150 let path1 = path.to_owned();
151
152 let data_watcher = move |event: WatchedEvent| {
153 let data = data.clone();
154 let ops_chan = ops_chan.clone();
155 let path1 = path1.clone();
156
157 tokio::spawn(async move {
158 let mut data_locked = data.lock().await;
159 match event.event_type {
160 WatchedEventType::NodeDeleted => {
161 data_locked.remove(&path1);
162
163 if let Err(err) = ops_chan.send(Operation::Event(
164 PathChildrenCacheEvent::ChildRemoved(path1.clone()),
165 )) {
166 warn!("error sending ChildRemoved event: {:?}", err);
167 }
168 }
169 WatchedEventType::NodeDataChanged => {
170 if let Err(err) = ops_chan.send(Operation::GetData(path1.clone())) {
172 warn!("error sending GetData to op channel: {:?}", err);
173 }
174 }
175 _ => error!("Unexpected: {:?}", event),
176 };
177
178 trace!("New data: {:?}", *data_locked);
179 });
180 };
181
182 zk.get_data_w(path, data_watcher).await
183 }
184
185 async fn update_data(
186 zk: Arc<ZooKeeper>,
187 path: &str,
188 data: Arc<Mutex<Data>>,
189 ops_chan_tx: UnboundedSender<Operation>,
190 ) -> ZkResult<()> {
191 let mut data_locked = data.lock().await;
192
193 let path = path.to_owned();
194
195 let result = Self::get_data(zk.clone(), &path, data.clone(), ops_chan_tx.clone()).await;
196
197 match result {
198 Ok(child_data) => {
199 trace!("got data {:?}", child_data);
200
201 let child_data = Arc::new(child_data);
202
203 data_locked.insert(path.clone(), child_data.clone());
204
205 ops_chan_tx
206 .send(Operation::Event(PathChildrenCacheEvent::ChildUpdated(
207 path, child_data,
208 )))
209 .map_err(|err| {
210 warn!("error sending ChildUpdated event: {:?}", err);
211 ZkError::APIError
212 })
213 }
214 Err(err) => {
215 warn!("error getting child data: {:?}", err);
216 Err(ZkError::APIError)
217 }
218 }
219 }
220
221 pub async fn get_current_data(&self) -> Data {
224 self.data.lock().await.clone()
225 }
226
227 pub async fn clear(&self) {
228 self.data.lock().await.clear()
229 }
230
231 fn handle_state_change(state: ZkState, ops_chan_tx: UnboundedSender<Operation>) -> bool {
232 let mut done = false;
233
234 debug!("zk state change {:?}", state);
235 if let ZkState::Connected = state {
236 if let Err(err) = ops_chan_tx.send(Operation::Refresh(RefreshMode::ForceGetDataAndStat))
237 {
238 warn!("error sending Refresh to op channel: {:?}", err);
239 done = true;
240 }
241 }
242
243 done
244 }
245
246 async fn handle_operation(
247 op: Operation,
248 zk: Arc<ZooKeeper>,
249 path: Arc<String>,
250 data: Arc<Mutex<Data>>,
251 event_listeners: ListenerSet<PathChildrenCacheEvent>,
252 ops_chan_tx: UnboundedSender<Operation>,
253 ) -> bool {
254 let mut done = false;
255
256 match op {
257 Operation::Initialize => {
258 debug!("initialising...");
259 let result = Self::get_children(
260 zk.clone(),
261 &path,
262 data.clone(),
263 ops_chan_tx.clone(),
264 RefreshMode::ForceGetDataAndStat,
265 )
266 .await;
267 debug!("got children {:?}", result);
268
269 event_listeners.notify(&PathChildrenCacheEvent::Initialized(
270 data.lock().await.clone(),
271 ));
272 }
273 Operation::Shutdown => {
274 debug!("shutting down worker thread");
275 done = true;
276 }
277 Operation::Refresh(mode) => {
278 debug!("getting children");
279 let result =
280 Self::get_children(zk.clone(), &path, data.clone(), ops_chan_tx.clone(), mode)
281 .await;
282 debug!("got children {:?}", result);
283 }
284 Operation::GetData(path) => {
285 debug!("getting data");
286 let result =
287 Self::update_data(zk.clone(), &path, data.clone(), ops_chan_tx.clone()).await;
288 if let Err(err) = result {
289 warn!("error getting child data: {:?}", err);
290 }
291 }
292 Operation::Event(event) => {
293 debug!("received event {:?}", event);
294 event_listeners.notify(&event);
295 }
296 Operation::ZkStateEvent(state) => {
297 done = Self::handle_state_change(state, ops_chan_tx.clone());
298 }
299 }
300
301 done
302 }
303
304 pub fn start(&mut self) -> ZkResult<()> {
306 let (ops_chan_tx, mut ops_chan_rx) = unbounded_channel();
307 let ops_chan_rx_zk_events = ops_chan_tx.clone();
308
309 let sub = self.zk.add_listener(move |s| {
310 ops_chan_rx_zk_events
311 .send(Operation::ZkStateEvent(s))
312 .unwrap()
313 });
314 self.listener_subscription = Some(sub);
315
316 let zk = self.zk.clone();
317 let path = self.path.clone();
318 let data = self.data.clone();
319 let event_listeners = self.event_listeners.clone();
320 self.channel = Some(ops_chan_tx.clone());
321
322 tokio::spawn(async move {
323 let mut done = false;
324
325 while !done {
326 match ops_chan_rx.recv().await {
327 Some(operation) => {
328 done = Self::handle_operation(
329 operation,
330 zk.clone(),
331 path.clone(),
332 data.clone(),
333 event_listeners.clone(),
334 ops_chan_tx.clone(),
335 )
336 .await;
337 }
338 None => {
339 info!("error receiving from operations channel. shutting down");
340 done = true;
341 }
342 }
343 }
344 });
345
346 self.offer_operation(Operation::Initialize)
347 }
348
349 pub fn add_listener<Listener: Fn(PathChildrenCacheEvent) + Send + 'static>(
350 &self,
351 subscriber: Listener,
352 ) -> Subscription {
353 self.event_listeners.subscribe(subscriber)
354 }
355
356 pub fn remove_listener(&self, sub: Subscription) {
357 self.event_listeners.unsubscribe(sub)
358 }
359
360 fn offer_operation(&self, op: Operation) -> ZkResult<()> {
361 match self.channel {
362 Some(ref chan) => chan.send(op).map_err(|err| {
363 warn!("error submitting op to channel: {:?}", err);
364 ZkError::APIError
365 }),
366 None => Err(ZkError::APIError),
367 }
368 }
369}