// zookeeper_cache_rust/cache.rs

use crate::tree::Tree;
use crate::{ChildData, Event};
use crate::{EventStream, Result, SharedChildData};
use async_recursion::async_recursion;
use futures::StreamExt;
use futures::{stream, Stream};
use std::collections::{HashMap, HashSet};
use std::mem;
use std::ops::DerefMut;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, RwLockWriteGuard};
use tokio_util::sync::CancellationToken;
use zookeeper_client::{EventType, SessionState, WatchedEvent};

type Path = String;

struct Storage {
    data: HashMap<Path, SharedChildData>,
    tree: Tree<Path>,
}

impl Storage {
    pub fn new(root: String) -> Storage {
        Storage {
            data: HashMap::new(),
            tree: Tree::new(root),
        }
    }

    pub fn replace(&mut self, data: HashMap<Path, SharedChildData>, tree: Tree<Path>) -> Storage {
        Storage {
            data: mem::replace(&mut self.data, data),
            tree: mem::replace(&mut self.tree, tree),
        }
    }
}

#[derive(Clone, Debug)]
pub(crate) struct Version(u32, u32, u32);

#[derive(Clone, Debug)]
pub struct AuthPacket {
    pub scheme: String,
    pub auth: Vec<u8>,
}

/// CacheBuilder holds the configuration used to build a [`Cache`]
#[derive(Clone, Debug)]
pub struct CacheBuilder {
    /// The root path to watch
    path: String,
    /// The auth packets to register with the ZooKeeper client
    authes: Vec<AuthPacket>,
    /// The ZooKeeper server version
    server_version: Version,
    /// Session timeout
    session_timeout: Duration,
    /// Connection timeout
    connection_timeout: Duration,
    /// When the session expires, reconnection is retried at this interval
    reconnect_timeout: Duration,
}

impl Default for CacheBuilder {
    fn default() -> Self {
        Self {
            path: "/".to_string(),
            authes: vec![],
            server_version: Version(u32::MAX, u32::MAX, u32::MAX),
            session_timeout: Duration::ZERO,
            connection_timeout: Duration::ZERO,
            reconnect_timeout: Duration::from_secs(1),
        }
    }
}

impl From<&CacheBuilder> for zookeeper_client::Connector {
    fn from(val: &CacheBuilder) -> Self {
        let mut connector = zookeeper_client::Client::connector();
        connector.server_version(
            val.server_version.0,
            val.server_version.1,
            val.server_version.2,
        );
        for auth in val.authes.clone() {
            connector.auth(auth.scheme, auth.auth);
        }
        connector.session_timeout(val.session_timeout);
        connector.connection_timeout(val.connection_timeout);
        connector.readonly(true);
        connector
    }
}

/// CacheBuilder configures and builds a [`Cache`]
///```no_run
/// use std::time::Duration;
/// use zookeeper_cache_rust::CacheBuilder;
/// async fn dox() -> zookeeper_cache_rust::Result<()> {
///    let builder = CacheBuilder::new("/test")
///                .with_version(3, 9, 1)
///                .with_connect_timeout(Duration::from_secs(10))
///                .with_session_timeout(Duration::from_secs(10))
///                .with_reconnect_timeout(Duration::from_secs(1));
///    let (_cache, _stream) = builder.build("localhost:2181").await?;
///    Ok(())
/// }
/// ```
impl CacheBuilder {
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            ..Default::default()
        }
    }

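    /// Add an auth packet that is registered with the ZooKeeper client before it connects.
    ///
    /// A minimal sketch; the digest scheme and credential bytes below are purely illustrative:
    ///```no_run
    /// use zookeeper_cache_rust::CacheBuilder;
    /// let _builder = CacheBuilder::new("/test")
    ///     .with_auth("digest".to_string(), b"user:password".to_vec());
    /// ```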
    pub fn with_auth(mut self, scheme: String, auth: Vec<u8>) -> Self {
        self.authes.push(AuthPacket { scheme, auth });
        self
    }

    pub fn with_version(mut self, major: u32, minor: u32, patch: u32) -> Self {
        self.server_version = Version(major, minor, patch);
        self
    }

    pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
        self.session_timeout = timeout;
        self
    }

    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    pub fn with_reconnect_timeout(mut self, timeout: Duration) -> Self {
        self.reconnect_timeout = timeout;
        self
    }

    pub async fn build(
        self,
        addr: impl Into<String>,
    ) -> Result<(Cache, impl Stream<Item = Event>)> {
        Cache::new(addr, self).await
    }
}

/// Cache watches the root node and its children recursively
///```no_run
/// use futures::StreamExt;
/// use zookeeper_cache_rust::CacheBuilder;
/// async fn dox() -> zookeeper_cache_rust::Result<()> {
///    let (cache, mut stream) = CacheBuilder::default().build("localhost:2181").await?;
///    tokio::spawn(async move {
///        while let Some(_event) = stream.next().await {
///            // handle event
///        }
///    });
///    cache.get("/test").await;
///    Ok(())
/// }
/// ```
pub struct Cache {
    addr: String,
    builder: CacheBuilder,
    storage: Arc<RwLock<Storage>>,
    event_sender: tokio::sync::mpsc::UnboundedSender<Event>,
    token: CancellationToken,
}

impl Drop for Cache {
    fn drop(&mut self) {
        self.token.cancel();
    }
}

impl Cache {
    pub async fn new(
        addr: impl Into<String>,
        builder: CacheBuilder,
    ) -> Result<(Self, impl Stream<Item = Event>)> {
        let mut connector: zookeeper_client::Connector = (&builder).into();
        let addr = addr.into();
        let client = connector.connect(&addr).await?;
        let storage = Arc::new(RwLock::new(Storage::new(builder.path.clone())));
        let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
        let events = EventStream { watcher };
        let cache = Self {
            addr,
            builder: builder.clone(),
            storage,
            event_sender: sender,
            token: CancellationToken::new(),
        };
        let (sender, watcher) = tokio::sync::mpsc::unbounded_channel();
        Self::init_nodes(
            &client,
            &builder.path,
            cache.storage.write().await.deref_mut(),
            &sender,
            &cache.event_sender,
        )
        .await?;
        cache.watch(client, sender, watcher).await;
        Ok((cache, events))
    }

    /// Get the data and stat stored for `path`
    ///```no_run
    /// use zookeeper_cache_rust::CacheBuilder;
    /// async fn dox() -> zookeeper_cache_rust::Result<()> {
    ///    let (cache, _stream) = CacheBuilder::default().build("localhost:2181").await?;
    ///    cache.get("/test").await;
    ///    Ok(())
    /// }
    /// ```
    pub async fn get(&self, path: &str) -> Option<SharedChildData> {
        self.storage.read().await.data.get(path).cloned()
    }

    async fn init_nodes(
        client: &zookeeper_client::Client,
        path: &str,
        storage: &mut Storage,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) -> Result<()> {
        let new = Arc::new(RwLock::new(Storage::new(path.to_string())));
        Self::fetch_all(client, path, &mut new.write().await, sender, true).await?;
        // emit events for the nodes that already exist
        let new = new.write().await;
        Self::compare_storage(path, storage, &new, event_sender).await;
        storage.replace(new.data.clone(), new.tree.clone());
        Ok(())
    }

    async fn watch(
        &self,
        mut client: zookeeper_client::Client,
        sender: tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        mut watcher: tokio::sync::mpsc::UnboundedReceiver<WatchedEvent>,
    ) {
        let addr = self.addr.clone();
        let storage = self.storage.clone();
        let sender = sender.clone();
        let builder = self.builder.clone();
        let event_sender = self.event_sender.clone();
        let token = self.token.clone();
        tokio::spawn(async move {
            let mut control = HandleControl::Handle;
            loop {
                tokio::select! {
                    _ = token.cancelled() => {
                        return
                    }
                    event = watcher.recv() => {
                        match event {
                            Some(event) => {
                                match control {
                                    HandleControl::Handle => {},
                                    HandleControl::Continue => {
                                        if event.event_type == EventType::Session && event.session_state.is_terminated() {
                                            continue;
                                        } else {
                                            control = HandleControl::Handle;
                                        }
                                    }
                                };
                                if let Some(reconnect) = Self::handle_event(&addr, &client, &builder, &storage, event, &sender, &event_sender, &token).await {
                                    client = reconnect;
                                    // skip the remaining session-expired events from the old session
                                    control = HandleControl::Continue;
                                }
                            }
                            None => break
                        }
                    }
                }
            }
        });
    }

    #[allow(clippy::too_many_arguments)]
    async fn handle_event(
        addr: &str,
        client: &zookeeper_client::Client,
        builder: &CacheBuilder,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
        token: &CancellationToken,
    ) -> Option<zookeeper_client::Client> {
        match event.event_type {
            EventType::Session => {
                if let Some(client) =
                    Self::handle_session(addr, builder, storage, event, sender, event_sender, token)
                        .await
                {
                    return Some(client);
                }
            }
            EventType::NodeDeleted => {
                Self::handle_node_deleted(storage, event, event_sender).await;
            }
            EventType::NodeDataChanged => {
                Self::handle_node_data_changed(client, storage, event, sender, event_sender).await;
            }
            EventType::NodeChildrenChanged => {
                Self::handle_children_change(client, storage, event, sender, event_sender).await;
            }
            EventType::NodeCreated => {
                Self::handle_node_created(client, storage, event, sender, event_sender).await;
            }
        }
        None
    }

    async fn handle_session(
        addr: &str,
        builder: &CacheBuilder,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
        token: &CancellationToken,
    ) -> Option<zookeeper_client::Client> {
        // TODO: add logging
        match event.session_state {
            SessionState::Expired | SessionState::Closed => {
                let mut interval = tokio::time::interval(builder.reconnect_timeout);
                let mut connector: zookeeper_client::Connector = builder.into();
                let client = loop {
                    tokio::select! {
                        _ = token.cancelled() => {
                            return None
                        }
                        _ = interval.tick() => {
                            match connector.connect(addr).await {
                                Ok(zk) => break zk,
                                Err(_err) => {}
                            };
                        }
                    }
                };
                loop {
                    match Self::init_nodes(
                        &client,
                        &builder.path,
                        storage.write().await.deref_mut(),
                        sender,
                        event_sender,
                    )
                    .await
                    {
                        Ok(_) => break,
                        Err(_err) => {
                            interval.tick().await;
                        }
                    }
                }
                return Some(client);
            }
            _ => {}
        };
        None
    }

    /// Only used when the root node is created
    async fn handle_node_created(
        client: &zookeeper_client::Client,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let mut storage = storage.write().await;
        if let Ok(status) = Self::get_root_node(client, &event.path, &mut storage, sender).await {
            match status {
                RootStatus::Ephemeral(data) => {
                    let _ = event_sender.send(Event::Add(data));
                }
                RootStatus::Persistent(data) => {
                    if let Err(err) = Self::list_children(client, &event.path, sender).await {
                        debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                    }
                    let _ = event_sender.send(Event::Add(data));
                }
                _ => {}
            }
        }
    }

    async fn handle_node_deleted(
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let mut storage = storage.write().await;
        storage.tree.remove_child(&event.path);
        if let Some(child_data) = storage.data.remove(&event.path) {
            let _ = event_sender.send(Event::Delete(child_data));
        }
    }

    async fn handle_node_data_changed(
        client: &zookeeper_client::Client,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let mut storage = storage.write().await;
        let old = storage.data.get(&event.path).unwrap().clone();
        if let Err(err) = Self::get_data(client, &event.path, &mut storage, sender).await {
            debug_assert_eq!(err, zookeeper_client::Error::NoNode);
            // the node was deleted before we could re-read it
            storage.tree.remove_child(&event.path);
            let child_data = storage.data.remove(&event.path).unwrap();
            let _ = event_sender.send(Event::Delete(child_data));
            return;
        };
        let new = storage.data.get(&event.path).unwrap().clone();
        let _ = event_sender.send(Event::Update { old, new });
    }

    async fn handle_children_change(
        client: &zookeeper_client::Client,
        storage: &Arc<RwLock<Storage>>,
        event: WatchedEvent,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        event_sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let old_children = storage
            .read()
            .await
            .tree
            .children(&event.path)
            .into_iter()
            .map(|child| child.to_string())
            .collect::<Vec<_>>();
        let new_children = match Self::list_children(client, &event.path, sender).await {
            Ok(children) => children
                .iter()
                .map(|child| make_path(&event.path, child))
                .collect::<Vec<_>>(),
            Err(err) => {
                debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                return;
            }
        };
        let (added, _) = compare(&old_children, &new_children);
        // only added nodes are handled here; removals arrive as NodeDeleted events
        let added = added
            .into_iter()
            .map(|added| {
                let zk = client.clone();
                let path = event.path.clone();
                let sender = sender.clone();
                let event_sender = event_sender.clone();
                (zk, storage, path, added, sender, event_sender)
            })
            .collect::<Vec<_>>();
        stream::iter(added)
            .for_each_concurrent(
                // fetch the new children with bounded concurrency
                20,
                |(zk, storage, parent, child_path, sender, event_sender)| async move {
                    let mut storage = storage.write().await;
                    if let Err(err) =
                        Self::get_data(&zk, &child_path, &mut storage, &sender).await
                    {
                        debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                        return;
                    }
                    storage.tree.add_child(&parent, child_path.clone());
                    let child_data = storage.data.get(&child_path).unwrap();
                    let _ = event_sender.send(Event::Add(child_data.clone()));
                },
            )
            .await;
    }

    async fn get_data(
        client: &zookeeper_client::Client,
        path: &str,
        storage: &mut RwLockWriteGuard<'_, Storage>,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
    ) -> std::result::Result<SharedChildData, zookeeper_client::Error> {
        let (data, stat, watcher) = client.get_and_watch_data(path).await?;
        let data = Arc::new(ChildData {
            path: path.to_string(),
            data,
            stat,
        });
        storage.data.insert(path.to_string(), data.clone());
        {
            let sender = sender.clone();
            tokio::spawn(async move {
                let _ = sender.send(watcher.changed().await);
            });
        }
        Ok(data)
    }

    async fn list_children(
        client: &zookeeper_client::Client,
        path: &str,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
    ) -> std::result::Result<Vec<String>, zookeeper_client::Error> {
        let (children, watcher) = client.list_and_watch_children(path).await?;
        {
            let sender = sender.clone();
            tokio::spawn(async move {
                let _ = sender.send(watcher.changed().await);
            });
        }
        Ok(children)
    }

    /// Fetch the root node and report whether it is missing, ephemeral, or persistent
    #[async_recursion]
    async fn get_root_node(
        client: &zookeeper_client::Client,
        path: &str,
        storage: &mut RwLockWriteGuard<'_, Storage>,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
    ) -> std::result::Result<RootStatus, zookeeper_client::Error> {
        match client.check_and_watch_stat(path).await? {
            (None, watcher) => {
                let sender = sender.clone();
                tokio::spawn(async move {
                    let _ = sender.send(watcher.changed().await);
                });
                Ok(RootStatus::NotExist)
            }
            (Some(_), _) => {
                match Self::get_data(client, path, storage, sender).await {
                    Ok(data) if data.stat.ephemeral_owner != 0 => {
                        Ok(RootStatus::Ephemeral(data.clone()))
                    }
                    Ok(data) => Ok(RootStatus::Persistent(data.clone())),
                    Err(err) => {
                        debug_assert_eq!(err, zookeeper_client::Error::NoNode);
                        // the node was deleted between the stat check and the read, so retry
                        Self::get_root_node(client, path, storage, sender).await
                    }
                }
            }
        }
    }

    #[async_recursion]
    async fn fetch_all(
        client: &zookeeper_client::Client,
        path: &str,
        storage: &mut RwLockWriteGuard<Storage>,
        sender: &tokio::sync::mpsc::UnboundedSender<WatchedEvent>,
        root: bool,
    ) -> std::result::Result<(), zookeeper_client::Error> {
        let persistent = if root {
            matches!(
                Self::get_root_node(client, path, storage, sender).await?,
                RootStatus::Persistent(_)
            )
        } else {
            Self::get_data(client, path, storage, sender)
                .await?
                .stat
                .ephemeral_owner
                == 0
        };
        if persistent {
            let children = match Self::list_children(client, path, sender).await {
                Ok(children) => children,
                Err(_) => return Ok(()),
            };
            storage.tree.add_children(
                path,
                children
                    .iter()
                    .map(|child| make_path(path, child))
                    .collect(),
            );
            for child in children.iter() {
                if let Err(zookeeper_client::Error::NoNode) = Self::fetch_all(
                    client,
                    make_path(path, child).as_str(),
                    storage,
                    sender,
                    false,
                )
                .await
                {
                    continue;
                }
            }
        }
        Ok(())
    }

    #[async_recursion]
    async fn compare_storage(
        path: &str,
        old: &Storage,
        new: &Storage,
        sender: &tokio::sync::mpsc::UnboundedSender<Event>,
    ) {
        let old_data = old.data.get(path);
        let new_data = new.data.get(path);
        match (old_data, new_data) {
            (Some(data), None) => {
                let _ = sender.send(Event::Delete(data.clone()));
            }
            (None, Some(data)) => {
                let _ = sender.send(Event::Add(data.clone()));
            }
            (Some(old), Some(new)) => {
                if !old.eq(new) {
                    let _ = sender.send(Event::Update {
                        old: old.clone(),
                        new: new.clone(),
                    });
                }
            }
            _ => {}
        }
        let mut old_children = old.tree.children(path);
        let mut new_children = new.tree.children(path);
        old_children.append(&mut new_children);
        let children = old_children.into_iter().collect::<HashSet<_>>();
        for child in children.iter() {
            Self::compare_storage(child, old, new, sender).await;
        }
    }
}

fn make_path(parent: &str, child: &str) -> String {
    if let Some('/') = parent.chars().last() {
        format!("{}{}", parent, child)
    } else {
        format!("{}/{}", parent, child)
    }
}

fn compare(old: &[String], new: &[String]) -> (Vec<String>, Vec<String>) {
    let old_map = old.iter().collect::<HashSet<_>>();
    let new_map = new.iter().collect::<HashSet<_>>();
    let and = &new_map & &old_map;
    (
        (&new_map ^ &and)
            .into_iter()
            .map(|s| s.to_string())
            .collect(),
        (&old_map ^ &and)
            .into_iter()
            .map(|s| s.to_string())
            .collect(),
    )
}

#[derive(Clone, Debug)]
enum RootStatus {
    NotExist,
    Ephemeral(SharedChildData),
    Persistent(SharedChildData),
}

#[derive(Clone, Debug)]
enum HandleControl {
    Handle,
    Continue,
}

#[cfg(test)]
mod tests {
    #[test]
    fn compare() {
        let old = ["1".to_string(), "2".to_string(), "3".to_string()];
        let new = ["2".to_string(), "3".to_string(), "4".to_string()];
        let (added, deleted) = super::compare(&old, &new);
        assert_eq!(added, vec!["4".to_string()]);
        assert_eq!(deleted, vec!["1".to_string()]);
    }
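
    // Additional illustrative checks (a minimal sketch): `make_path` should join
    // parent and child with exactly one separator, and `compare` with no previous
    // children should report every new entry as added and nothing as deleted.
    #[test]
    fn make_path_joins_segments() {
        assert_eq!(super::make_path("/", "child"), "/child");
        assert_eq!(super::make_path("/root", "child"), "/root/child");
    }

    #[test]
    fn compare_with_empty_old() {
        let old: [String; 0] = [];
        let new = ["a".to_string()];
        let (added, deleted) = super::compare(&old, &new);
        assert_eq!(added, vec!["a".to_string()]);
        assert!(deleted.is_empty());
    }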
}