// zookeeper_cache/cache.rs

1use crate::tree::Tree;
2use crate::{ChildData, Event};
3use crate::{EventStream, Result, SharedChildData};
4use async_recursion::async_recursion;
5use futures::StreamExt;
6use futures::{stream, Stream};
7use std::collections::{HashMap, HashSet};
8use std::mem;
9use std::ops::DerefMut;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{RwLock, RwLockWriteGuard};
13use tokio_util::sync::CancellationToken;
14use zookeeper_client::{EventType, SessionState, WatchedEvent};
15
// Node paths are plain strings (e.g. "/app/node").
type Path = String;
/// In-memory snapshot of the watched subtree.
struct Storage {
    // path -> latest fetched data/stat for that node
    data: HashMap<Path, SharedChildData>,
    // parent/child relationships between the cached paths
    tree: Tree<Path>,
}
21
22impl Storage {
23    pub fn new(root: String) -> Storage {
24        Storage {
25            data: HashMap::new(),
26            tree: Tree::new(root),
27        }
28    }
29
30    pub fn replace(&mut self, data: HashMap<Path, SharedChildData>, tree: Tree<Path>) -> Storage {
31        Storage {
32            data: mem::replace(&mut self.data, data),
33            tree: mem::replace(&mut self.tree, tree),
34        }
35    }
36}
37
/// Zookeeper server version as (major, minor, patch), forwarded to
/// `zookeeper_client::Connector::server_version` when connecting.
#[derive(Clone, Debug)]
pub(crate) struct Version(u32, u32, u32);
40
/// One authentication credential registered on the Zookeeper session.
#[derive(Clone, Debug)]
pub struct AuthPacket {
    /// Auth scheme name passed to the client.
    pub scheme: String,
    /// Raw auth payload.
    pub auth: Vec<u8>,
}
46
/// CacheBuilder is the configuration of Cache
#[derive(Clone, Debug)]
pub struct CacheBuilder {
    /// The root path to be watched
    path: String,
    /// Auth credentials added to the Zookeeper session
    authes: Vec<AuthPacket>,
    /// The version of the Zookeeper server
    server_version: Version,
    /// Session timeout
    session_timeout: Duration,
    /// Connect timeout
    connection_timeout: Duration,
    /// Retry interval used when reconnecting after the session expired
    reconnect_timeout: Duration,
}
63
64impl Default for CacheBuilder {
65    fn default() -> Self {
66        Self {
67            path: "/".to_string(),
68            authes: vec![],
69            server_version: Version(u32::MAX, u32::MAX, u32::MAX),
70            session_timeout: Duration::ZERO,
71            connection_timeout: Duration::ZERO,
72            reconnect_timeout: Duration::from_secs(1),
73        }
74    }
75}
76
77impl From<&CacheBuilder> for zookeeper_client::Connector {
78    fn from(val: &CacheBuilder) -> Self {
79        let mut connector = zookeeper_client::Client::connector();
80        connector.server_version(
81            val.server_version.0,
82            val.server_version.1,
83            val.server_version.2,
84        );
85        for auth in val.authes.clone() {
86            connector.auth(auth.scheme, auth.auth);
87        }
88        connector.session_timeout(val.session_timeout);
89        connector.connection_timeout(val.connection_timeout);
90        connector.readonly(true);
91        connector
92    }
93}
94
/// CacheBuilder configures and builds a [`Cache`].
///```no_run
/// use std::time::Duration;
/// use zookeeper_cache::CacheBuilder;
/// async fn dox() -> zookeeper_cache::Result<()>{
///    let builder = CacheBuilder::new("/test")
///                .with_version(3,9,1)
///                .with_connect_timeout(Duration::from_secs(10))
///                .with_session_timeout(Duration::from_secs(10))
///                .with_reconnect_timeout(Duration::from_secs(1));
///    let (_cache,_stream) = builder.build("localhost:2181").await?;
///    Ok(())
/// }
/// ```
impl CacheBuilder {
    /// Creates a builder watching `path`, with all other settings defaulted.
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            ..Default::default()
        }
    }

    /// Adds an auth credential used when connecting.
    pub fn with_auth(mut self, scheme: String, auth: Vec<u8>) -> Self {
        self.authes.push(AuthPacket { scheme, auth });
        self
    }

    /// Declares the Zookeeper server version (major, minor, patch).
    pub fn with_version(mut self, major: u32, minor: u32, patch: u32) -> Self {
        self.server_version = Version(major, minor, patch);
        self
    }

    /// Sets the session timeout.
    pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
        self.session_timeout = timeout;
        self
    }

    /// Sets the connect timeout.
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the retry interval used when re-establishing an expired session.
    pub fn with_reconnect_timeout(mut self, timeout: Duration) -> Self {
        self.reconnect_timeout = timeout;
        self
    }

    /// Connects to `addr` and builds the cache plus its event stream.
    pub async fn build(
        self,
        addr: impl Into<String>,
    ) -> Result<(Cache, impl Stream<Item = Event>)> {
        Cache::new(addr, self).await
    }
}
148
/// Cache watches the root node and its children recursively.
///```no_run
/// use futures::StreamExt;
/// use zookeeper_cache::CacheBuilder;
/// async fn dox() -> zookeeper_cache::Result<()>{
///    let (cache,mut stream) = CacheBuilder::default().build("localhost:2181").await?;
///        tokio::spawn(async move{
///            while let Some(_event) = stream.next().await{
///                // handle event
///            }
///        });
///    cache.get("/test").await;
///    Ok(())
/// }
/// ```
pub struct Cache {
    // Server address, kept so the session can be re-established after expiry.
    addr: String,
    // Original configuration, reused when rebuilding the session.
    builder: CacheBuilder,
    // Shared snapshot of the watched subtree.
    storage: Arc<RwLock<Storage>>,
    // Producer side of the user-facing event stream.
    event_sender: tokio::sync::mpsc::UnboundedSender<Event>,
    // Cancels the background watch task when the cache is dropped.
    token: CancellationToken,
}
170
impl Drop for Cache {
    fn drop(&mut self) {
        // Stop the background watch task spawned in `Cache::watch`.
        self.token.cancel();
    }
}
176
177impl Cache {
    /// Connects to `addr`, loads the initial snapshot of `builder.path`,
    /// and spawns the background task that keeps the cache up to date.
    ///
    /// Returns the cache together with a stream of [`Event`]s; the nodes
    /// that already exist at startup are emitted as `Add` events.
    pub async fn new(
        addr: impl Into<String>,
        builder: CacheBuilder,
    ) -> Result<(Self, impl Stream<Item = Event>)> {
        let mut connector: zookeeper_client::Connector = (&builder).into();
        let addr = addr.into();
        let client = connector.connect(&addr).await?;
        let storage = Arc::new(RwLock::new(Storage::new(builder.path.clone())));
        // Channel for user-facing cache events; `events` is handed to the caller.
        let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
        let events = EventStream { watcher };
        let cache = Self {
            addr,
            builder: builder.clone(),
            storage,
            event_sender: sender,
            token: CancellationToken::new(),
        };
        // Separate channel carrying raw Zookeeper watch notifications,
        // consumed by the task spawned in `watch` below.
        let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
        Self::init_nodes(
            &client,
            &builder.path,
            cache.storage.write().await.deref_mut(),
            &sender,
            &cache.event_sender,
        )
        .await?;
        cache.watch(client, sender, watcher).await;
        Ok((cache, events))
    }
207
208    /// Get data and stat through path
209    ///```no_run
210    /// use zookeeper_cache::CacheBuilder;
211    /// async fn dox()->zookeeper_cache::Result<()>{
212    ///    let (cache, _stream) = CacheBuilder::default().build("localhost:2181").await?;
213    ///    cache.get("/test").await;
214    ///    Ok(())
215    /// }
216    /// ```
217    pub async fn get(&self, path: &str) -> Option<SharedChildData> {
218        self.storage.read().await.data.get(path).cloned()
219    }
220
    /// Rebuilds the cache: fetches the whole subtree under `path` into a
    /// fresh Storage, publishes the diff against the current snapshot as
    /// events, then installs the new snapshot.
    async fn init_nodes(
        client: &zookeeper_client::Client,
        path: &str,
        storage: &mut Storage,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) -> Result<()> {
        // Scratch storage is wrapped in Arc<RwLock<..>> only because
        // `fetch_all` takes a write guard; the lock is never contended here.
        let new = Arc::new(RwLock::new(Storage::new(path.to_string())));
        Self::fetch_all(client, path, &mut new.write().await, sender, true).await?;
        // send events of existed node
        let new = new.write().await;
        Self::compare_storage(path, storage, &new, event_sender).await;
        // Swap the freshly fetched data/tree into the live storage.
        storage.replace(new.data.clone(), new.tree.clone());
        Ok(())
    }
236
237    async fn watch(
238        &self,
239        mut client: zookeeper_client::Client,
240        sender: tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
241        mut watcher: tokio::sync::mpsc::UnboundedReceiver<WatchedEvent>,
242    ) {
243        let addr = self.addr.clone();
244        let storage = self.storage.clone();
245        let sender = sender.clone();
246        let builder = self.builder.clone();
247        let event_sender = self.event_sender.clone();
248        let token = self.token.clone();
249        tokio::spawn(async move {
250            let mut control = HandleControl::Handle;
251            loop {
252                tokio::select! {
253                    _ = token.cancelled() => {
254                        return
255                    }
256                    event = watcher.recv() => {
257                        match event{
258                            Some(event) => {
259                                match control {
260                                    HandleControl::Handle => {},
261                                    HandleControl::Continue => {
262                                        if event.event_type == EventType::Session && event.session_state.is_terminated(){
263                                            continue;
264                                        } else {
265                                            control = HandleControl::Handle;
266                                        }
267                                    }
268                                };
269                                if let Some(reconnect) = Self::handle_event(&addr, &client, &builder, &storage, event, &sender, &event_sender, &token).await{
270                                    client = reconnect;
271                                    // to ignore the other session expired events
272                                    control = HandleControl::Continue;
273                                }
274                            }
275                            None => break
276                        }
277                    }
278                }
279            }
280        });
281    }
282
283    #[allow(clippy::too_many_arguments)]
284    async fn handle_event(
285        addr: &str,
286        client: &zookeeper_client::Client,
287        builder: &CacheBuilder,
288        storage: &Arc<RwLock<Storage>>,
289        event: WatchedEvent,
290        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
291        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
292        token: &CancellationToken,
293    ) -> Option<zookeeper_client::Client> {
294        match event.event_type {
295            EventType::Session => {
296                if let Some(client) =
297                    Self::handle_session(addr, builder, storage, event, sender, event_sender, token)
298                        .await
299                {
300                    return Some(client);
301                }
302            }
303            EventType::NodeDeleted => {
304                Self::handle_node_deleted(storage, event, event_sender).await;
305            }
306            EventType::NodeDataChanged => {
307                Self::handle_node_data_changed(client, storage, event, sender, event_sender).await;
308            }
309            EventType::NodeChildrenChanged => {
310                Self::handle_children_changed(client, storage, event, sender, event_sender).await;
311            }
312            EventType::NodeCreated => {
313                Self::handle_node_created(client, storage, event, sender, event_sender).await;
314            }
315        }
316        None
317    }
318
    /// Handles session state changes. On `Expired`/`Closed` it reconnects,
    /// retrying every `reconnect_timeout`, rebuilds the cache from the new
    /// session, and returns the replacement client; otherwise returns `None`.
    async fn handle_session(
        addr: &str,
        builder: &CacheBuilder,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
        token: &CancellationToken,
    ) -> Option<zookeeper_client::Client> {
        // todo add log
        match event.session_state {
            SessionState::Expired | SessionState::Closed => {
                let mut interval = tokio::time::interval(builder.reconnect_timeout);
                let mut connector: zookeeper_client::Connector = builder.into();
                // Retry connecting until it succeeds or the cache is dropped.
                let client = loop {
                    tokio::select! {
                        _ = token.cancelled() => {
                            return None
                        }
                        _ = interval.tick() => {
                             match connector.connect(addr).await {
                                Ok(zk) => break zk,
                                Err(_err) => {
                                }
                            };
                        }
                    }
                };
                {
                    // Re-fetch the whole tree with the new session, retrying
                    // on failure with the same interval.
                    loop {
                        match Self::init_nodes(
                            &client,
                            &builder.path,
                            storage.write().await.deref_mut(),
                            sender,
                            event_sender,
                        )
                        .await
                        {
                            Ok(_) => break,
                            Err(_err) => {
                                interval.tick().await;
                            }
                        }
                    }
                }
                return Some(client);
            }
            _ => {}
        };
        None
    }
371
372    /// only used when root node be created
373    async fn handle_node_created(
374        client: &zookeeper_client::Client,
375        storage: &Arc<RwLock<Storage>>,
376        event: WatchedEvent,
377        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
378        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
379    ) {
380        let mut storage = storage.write().await;
381        if let Ok(status) = Self::get_root_node(client, &event.path, &mut storage, sender).await {
382            match status {
383                RootStatus::Ephemeral(data) => {
384                    let _ = event_sender.send(Event::Add(data));
385                }
386                RootStatus::Persistent(data) => {
387                    if let Err(err) = Self::list_children(client, &event.path, sender).await {
388                        debug_assert_eq!(err, zookeeper_client::Error::NoNode);
389                    }
390                    let _ = event_sender.send(Event::Add(data));
391                }
392                _ => {}
393            }
394        }
395    }
396
397    async fn handle_node_deleted(
398        storage: &Arc<RwLock<Storage>>,
399        event: WatchedEvent,
400        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
401    ) {
402        let mut storage = storage.write().await;
403        storage.tree.remove_child(&event.path);
404        match storage.data.get(&event.path) {
405            None => {}
406            Some(_data) => {}
407        }
408        match storage.data.remove(&event.path) {
409            None => {}
410            Some(child_data) => {
411                let _ = event_sender.send(Event::Delete(child_data));
412            }
413        }
414    }
415
416    async fn handle_node_data_changed(
417        client: &zookeeper_client::Client,
418        storage: &Arc<RwLock<Storage>>,
419        event: WatchedEvent,
420        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
421        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
422    ) {
423        let mut storage = storage.write().await;
424        let old = storage.data.get(&event.path).unwrap().clone();
425        if let Err(err) = Self::get_data(client, &event.path, &mut storage, sender).await {
426            debug_assert_eq!(err, zookeeper_client::Error::NoNode);
427            // data deleted
428            storage.tree.remove_child(&event.path);
429            let child_data = storage.data.remove(&event.path).unwrap();
430            let _ = event_sender.send(Event::Delete(child_data));
431            return;
432        };
433        let new = storage.data.get(&event.path).unwrap().clone();
434        let _ = event_sender.send(Event::Update { old, new });
435    }
436
    /// Handles a children-changed notification on `event.path`: re-lists the
    /// children (re-arming the child watch) and fetches every newly added
    /// child concurrently. Removed children are not handled here — each node
    /// is removed via its own NodeDeleted event.
    async fn handle_children_changed(
        client: &zookeeper_client::Client,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let old_children = storage
            .read()
            .await
            .tree
            .children(&event.path)
            .into_iter()
            .map(|child| child.to_string())
            .collect::<Vec<_>>();
        let new_children = match Self::list_children(client, &event.path, sender).await {
            Ok(children) => children
                .iter()
                .map(|child| make_path(&event.path, child))
                .collect::<Vec<_>>(),
            Err(err) => {
                // The parent vanished between the event and the list call.
                debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                return;
            }
        };
        let (added, _) = compare(&old_children, &new_children);
        //only handle node added
        // Bundle everything each task needs so the closures below are 'static
        // apart from the shared `storage` reference.
        let added = added
            .into_iter()
            .map(|added| {
                let zk = client.clone();
                let path = event.path.clone();
                let sender = sender.clone();
                let event_sender = event_sender.clone();
                (zk, storage, path, added, sender, event_sender)
            })
            .collect::<Vec<_>>();
        stream::iter(added)
            .for_each_concurrent(
                // we fetch children through stream
                20,
                |(zk, storage, parent, child_path, sender, event_sender)| async move {
                    // NOTE(review): tasks still serialize on this write lock, so
                    // the concurrency limit mainly bounds in-flight setup.
                    let mut storage = storage.write().await;
                    let child_data =
                        match Self::get_data(&zk, &child_path, &mut storage, &sender).await {
                            Ok(data) => data,
                            Err(err) => {
                                // Child deleted before we could read it.
                                debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                                return;
                            }
                        };
                    storage.tree.add_child(&parent, child_path.clone());
                    // Persistent nodes (ephemeral_owner == 0) may have children
                    // of their own: watch them too.
                    if child_data.stat.ephemeral_owner == 0 {
                        if let Err(err) = Self::list_children(&zk, &child_path, &sender).await {
                            debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                        }
                    }
                    let _ = event_sender.send(Event::Add(child_data.clone()));
                },
            )
            .await;
    }
499
500    async fn get_data(
501        client: &zookeeper_client::Client,
502        path: &str,
503        storage: &mut RwLockWriteGuard<'_, Storage>,
504        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
505    ) -> std::result::Result<SharedChildData, zookeeper_client::Error> {
506        let (data, stat, watcher) = client.get_and_watch_data(path).await?;
507        let data = Arc::new(ChildData {
508            path: path.to_string(),
509            data,
510            stat,
511        });
512        storage.data.insert(path.to_string(), data.clone());
513        {
514            let sender = sender.clone();
515            tokio::spawn(async move {
516                let _ = sender.send(watcher.changed().await);
517            });
518        }
519        Ok(data)
520    }
521
522    async fn list_children(
523        client: &zookeeper_client::Client,
524        path: &str,
525        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
526    ) -> std::result::Result<Vec<String>, zookeeper_client::Error> {
527        let (children, watcher) = client.list_and_watch_children(path).await?;
528        {
529            let sender = sender.clone();
530            tokio::spawn(async move {
531                let _ = sender.send(watcher.changed().await);
532            });
533        }
534        Ok(children)
535    }
536
537    /// get the root node, if it exists return true, or return false
538    #[async_recursion]
539    async fn get_root_node(
540        client: &zookeeper_client::Client,
541        path: &str,
542        storage: &mut RwLockWriteGuard<'_, Storage>,
543        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
544    ) -> std::result::Result<RootStatus, zookeeper_client::Error> {
545        match client.check_and_watch_stat(path).await? {
546            (None, watcher) => {
547                let sender = sender.clone();
548                tokio::spawn(async move {
549                    let _ = sender.send(watcher.changed().await);
550                });
551                Ok(RootStatus::NotExist)
552            }
553            (Some(_), _) => {
554                match Self::get_data(client, path, storage, sender).await {
555                    Ok(data) if data.stat.ephemeral_owner != 0 => {
556                        Ok(RootStatus::Ephemeral(data.clone()))
557                    }
558                    Ok(data) => Ok(RootStatus::Persistent(data.clone())),
559                    Err(err) => {
560                        debug_assert_eq!(err, zookeeper_client::Error::NoNode);
561                        // if  node exist -> node deleted, we need to repeat this function
562                        Self::get_root_node(client, path, storage, sender).await
563                    }
564                }
565            }
566        }
567    }
568
569    #[async_recursion]
570    async fn fetch_all(
571        client: &zookeeper_client::Client,
572        path: &str,
573        storage: &mut RwLockWriteGuard<Storage>,
574        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
575        root: bool,
576    ) -> std::result::Result<(), zookeeper_client::Error> {
577        let persistent = if root {
578            matches!(
579                Self::get_root_node(client, path, storage, sender).await?,
580                RootStatus::Persistent(_)
581            )
582        } else {
583            Self::get_data(client, path, storage, sender)
584                .await?
585                .stat
586                .ephemeral_owner
587                == 0
588        };
589        if persistent {
590            let children = match Self::list_children(client, path, sender).await {
591                Ok(children) => children,
592                Err(_) => return Ok(()),
593            };
594            storage.tree.add_children(
595                path,
596                children
597                    .iter()
598                    .map(|child| make_path(path, child))
599                    .collect(),
600            );
601            for child in children.iter() {
602                if let Err(zookeeper_client::Error::NoNode) = Self::fetch_all(
603                    client,
604                    make_path(path, child).as_str(),
605                    storage,
606                    sender,
607                    false,
608                )
609                .await
610                {
611                    continue;
612                }
613            }
614        }
615        Ok(())
616    }
617
618    #[async_recursion]
619    async fn compare_storage(
620        path: &str,
621        old: &Storage,
622        new: &Storage,
623        sender: &tokio::sync::mpsc::UnboundedSender<Event>,
624    ) {
625        let old_data = old.data.get(path);
626        let new_data = new.data.get(path);
627        match (old_data, new_data) {
628            (Some(data), None) => {
629                let _ = sender.send(Event::Delete(data.clone()));
630            }
631            (None, Some(data)) => {
632                let _ = sender.send(Event::Add(data.clone()));
633            }
634            (Some(old), Some(new)) => {
635                if !old.eq(new) {
636                    let _ = sender.send(Event::Update {
637                        old: old.clone(),
638                        new: new.clone(),
639                    });
640                }
641            }
642            _ => {}
643        }
644        let mut old_children = old.tree.children(path);
645        let mut new_children = new.tree.children(path);
646        old_children.append(&mut new_children);
647        let children = old_children.into_iter().collect::<HashSet<_>>();
648        for child in children.iter() {
649            Self::compare_storage(child, old, new, sender).await;
650        }
651    }
652}
653
/// Joins `child` onto `parent`, inserting a `/` separator unless `parent`
/// already ends with one (e.g. the root path "/").
fn make_path(parent: &str, child: &str) -> String {
    // `ends_with` inspects only the suffix; the previous
    // `parent.chars().last()` walked the whole string.
    if parent.ends_with('/') {
        format!("{}{}", parent, child)
    } else {
        format!("{}/{}", parent, child)
    }
}
661
/// Diffs two child lists: returns `(added, deleted)` where `added` holds the
/// paths present only in `new` and `deleted` the paths present only in `old`.
///
/// Unlike the previous HashSet-symmetric-difference version, the output
/// preserves the order of the corresponding input slice, so results are
/// deterministic. Inputs are assumed to contain unique names, as Zookeeper
/// child lists do.
fn compare(old: &[String], new: &[String]) -> (Vec<String>, Vec<String>) {
    let old_set = old.iter().collect::<HashSet<_>>();
    let new_set = new.iter().collect::<HashSet<_>>();
    let added: Vec<String> = new
        .iter()
        .filter(|name| !old_set.contains(name))
        .cloned()
        .collect();
    let deleted: Vec<String> = old
        .iter()
        .filter(|name| !new_set.contains(name))
        .cloned()
        .collect();
    (added, deleted)
}
677
/// Outcome of probing the root node in `Cache::get_root_node`.
#[derive(Clone, Debug)]
enum RootStatus {
    /// Root does not exist yet; an exists-watch was left behind.
    NotExist,
    /// Root exists and is ephemeral; `fetch_all` does not recurse into it.
    Ephemeral(SharedChildData),
    /// Root exists and is persistent; its children are fetched and watched.
    Persistent(SharedChildData),
}
684
/// Controls how the watch loop treats incoming events right after a reconnect.
#[derive(Clone, Debug)]
enum HandleControl {
    /// Process every event normally.
    Handle,
    /// Skip terminal session events left over from the previous session.
    Continue,
}
690
#[cfg(test)]
mod tests {
    // `compare` should report names only in `new` as added and names only
    // in `old` as deleted.
    #[test]
    fn compare() {
        let old = ["1".to_string(), "2".to_string(), "3".to_string()];
        let new = ["2".to_string(), "3".to_string(), "4".to_string()];
        let (added, deleted) = super::compare(&old, &new);
        assert_eq!(added, vec!["4".to_string()]);
        assert_eq!(deleted, vec!["1".to_string()]);
    }
}
701}