1mod watcher;
2
3use std::borrow::Cow;
4use std::fmt::Write as _;
5use std::future::Future;
6use std::mem::ManuallyDrop;
7use std::sync::Arc;
8use std::time::Duration;
9
10use const_format::formatcp;
11use derive_where::derive_where;
12use either::{Either, Left, Right};
13use futures::channel::mpsc;
14use ignore_result::Ignore;
15use thiserror::Error;
16use tracing::instrument;
17
18pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
19use super::session::{Depot, MarshalledRequest, Request, Session, SessionOperation, WatchReceiver};
20use crate::acl::{Acl, Acls, AuthUser};
21use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
22use crate::endpoint::{self, IterableEndpoints};
23use crate::error::Error;
24use crate::proto::{
25 self,
26 AuthPacket,
27 CheckVersionRequest,
28 CreateRequest,
29 DeleteRequest,
30 ExistsRequest,
31 GetAclResponse,
32 GetChildren2Response,
33 GetChildrenRequest,
34 GetRequest,
35 MultiHeader,
36 MultiReadResponse,
37 MultiWriteResponse,
38 OpCode,
39 PersistentWatchRequest,
40 ReconfigRequest,
41 RequestBuffer,
42 RequestHeader,
43 SetAclRequest,
44 SetDataRequest,
45 SyncRequest,
46};
47pub use crate::proto::{EnsembleUpdate, Stat};
48use crate::record::{self, Record, StaticRecord};
49#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
50use crate::sasl::SaslOptions;
51use crate::session::StateReceiver;
52pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
53#[cfg(feature = "tls")]
54use crate::tls::TlsOptions;
55use crate::util;
56
57pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
58
/// Kind of node created by `Client::create`.
///
/// Only [`CreateMode::Persistent`] and [`CreateMode::PersistentSequential`] may be combined
/// with a TTL (see `CreateOptions::with_ttl`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateMode {
    /// Persistent node; may carry a TTL.
    Persistent,
    /// Persistent node whose name gets a server-assigned sequence suffix; may carry a TTL.
    PersistentSequential,
    /// Ephemeral node; its `Stat::ephemeral_owner` is the creating session.
    Ephemeral,
    /// Ephemeral node whose name gets a server-assigned sequence suffix.
    EphemeralSequential,
    /// Container node, created with the `CreateContainer` opcode.
    Container,
}
69
70impl CreateMode {
71 pub const fn with_acls(self, acls: Acls<'_>) -> CreateOptions<'_> {
73 CreateOptions { mode: self, acls, ttl: None }
74 }
75
76 fn is_sequential(self) -> bool {
77 self == CreateMode::PersistentSequential || self == CreateMode::EphemeralSequential
78 }
79
80 fn is_persistent(self) -> bool {
81 self == Self::Persistent || self == Self::PersistentSequential
82 }
83
84 fn is_ephemeral(self) -> bool {
85 self == Self::Ephemeral || self == Self::EphemeralSequential
86 }
87
88 fn is_container(self) -> bool {
89 self == CreateMode::Container
90 }
91
92 fn as_flags(self, ttl: bool) -> i32 {
93 use CreateMode::*;
94 match self {
95 Persistent => {
96 if ttl {
97 5
98 } else {
99 0
100 }
101 },
102 PersistentSequential => {
103 if ttl {
104 6
105 } else {
106 2
107 }
108 },
109 Ephemeral => 1,
110 EphemeralSequential => 3,
111 Container => 4,
112 }
113 }
114}
115
/// Mode of a watch registered by `Client::watch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddWatchMode {
    /// Persistent watch on the given node only.
    Persistent,

    /// Persistent watch on the given node and its descendants.
    PersistentRecursive,
}
125
126impl From<AddWatchMode> for proto::AddWatchMode {
127 fn from(mode: AddWatchMode) -> proto::AddWatchMode {
128 match mode {
129 AddWatchMode::Persistent => proto::AddWatchMode::Persistent,
130 AddWatchMode::PersistentRecursive => proto::AddWatchMode::PersistentRecursive,
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
137pub struct CreateOptions<'a> {
138 mode: CreateMode,
139 acls: Acls<'a>,
140 ttl: Option<Duration>,
141}
142
143const TTL_MAX_MILLIS: u128 = 0x00FFFFFFFFFF;
147
148impl<'a> CreateOptions<'a> {
149 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
151 self.ttl = Some(ttl);
152 self
153 }
154
155 fn validate(&'a self) -> Result<()> {
156 if let Some(ref ttl) = self.ttl {
157 if !self.mode.is_persistent() {
158 return Err(Error::BadArguments(&"ttl can only be specified with persistent node"));
159 } else if ttl.is_zero() {
160 return Err(Error::BadArguments(&"ttl is zero"));
161 } else if ttl.as_millis() > TTL_MAX_MILLIS {
162 return Err(Error::BadArguments(&formatcp!("ttl cannot larger than {}", TTL_MAX_MILLIS)));
163 }
164 }
165 if self.acls.is_empty() {
166 return Err(Error::InvalidAcl);
167 }
168 Ok(())
169 }
170
171 fn validate_as_directory(&self) -> Result<()> {
172 self.validate()?;
173 if self.mode.is_ephemeral() {
174 return Err(Error::BadArguments(&"directory node must not be ephemeral"));
175 } else if self.mode.is_sequential() {
176 return Err(Error::BadArguments(&"directory node must not be sequential"));
177 }
178 Ok(())
179 }
180}
181
/// Sequence number the server appended to a sequential node's name.
///
/// Non-sequential creates report `CreateSequence(-1)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CreateSequence(i64);

impl std::fmt::Display for CreateSequence {
    /// Zero-pads to 10 digits while the value fits in an `i32`, and to 19 digits otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = if self.0 > i64::from(i32::MAX) { 19 } else { 10 };
        write!(f, "{:0width$}", self.0, width = width)
    }
}

impl CreateSequence {
    /// Returns the raw sequence value.
    pub fn into_i64(self) -> i64 {
        let CreateSequence(value) = self;
        value
    }
}
206
207#[derive(Clone, Debug)]
221pub struct Client {
222 chroot: OwnedChroot,
223 version: Version,
224 session: SessionInfo,
225 session_timeout: Duration,
226 requester: Arc<mpsc::UnboundedSender<Request>>,
227 state_watcher: StateWatcher,
228}
229
230impl Client {
231 const CONFIG_NODE: &'static str = "/zookeeper/config";
232
233 pub async fn connect(cluster: &str) -> Result<Self> {
235 Self::connector().connect(cluster).await
236 }
237
238 pub fn connector() -> Connector {
240 Connector::new()
241 }
242
243 pub(crate) fn new(
244 chroot: OwnedChroot,
245 version: Version,
246 session: SessionInfo,
247 timeout: Duration,
248 requester: Arc<mpsc::UnboundedSender<Request>>,
249 state_watcher: StateWatcher,
250 ) -> Client {
251 Client { chroot, version, session, session_timeout: timeout, requester, state_watcher }
252 }
253
254 fn validate_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
255 ChrootPath::new(self.chroot.as_ref(), path, false)
256 }
257
258 fn validate_sequential_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
259 ChrootPath::new(self.chroot.as_ref(), path, true)
260 }
261
262 pub fn path(&self) -> &str {
264 self.chroot.path()
265 }
266
267 pub fn session(&self) -> &SessionInfo {
269 &self.session
270 }
271
272 pub fn session_id(&self) -> SessionId {
274 self.session().id()
275 }
276
277 pub fn into_session(self) -> SessionInfo {
279 self.session
280 }
281
282 pub fn session_timeout(&self) -> Duration {
284 self.session_timeout
285 }
286
287 pub fn state(&self) -> SessionState {
289 self.state_watcher.peek_state()
290 }
291
292 pub fn state_watcher(&self) -> StateWatcher {
294 let mut watcher = self.state_watcher.clone();
295 watcher.state();
296 watcher
297 }
298
299 pub fn chroot<'a>(mut self, path: impl Into<Cow<'a, str>>) -> std::result::Result<Client, Client> {
307 if self.chroot.chroot(path) {
308 Ok(self)
309 } else {
310 Err(self)
311 }
312 }
313
314 fn send_request(&self, code: OpCode, body: &impl Record) -> StateReceiver {
315 let request = MarshalledRequest::new(code, body);
316 self.send_marshalled_request(request)
317 }
318
319 fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
320 let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
321 if let Err(err) = self.requester.unbounded_send(operation.into()) {
322 let state = self.state();
323 err.into_inner().into_responser().send(Err(state.to_error()));
324 }
325 receiver
326 }
327
328 async fn wait<T, E, F>(result: std::result::Result<F, E>) -> std::result::Result<T, E>
329 where
330 F: Future<Output = std::result::Result<T, E>>, {
331 match result {
332 Err(err) => Err(err),
333 Ok(future) => future.await,
334 }
335 }
336
337 async fn resolve<T, E, F>(result: std::result::Result<Either<F, T>, E>) -> std::result::Result<T, E>
338 where
339 F: Future<Output = std::result::Result<T, E>>, {
340 match result {
341 Err(err) => Err(err),
342 Ok(Right(r)) => Ok(r),
343 Ok(Left(future)) => future.await,
344 }
345 }
346
347 async fn map_wait<T, U, Fu, Fn>(result: Result<Fu>, f: Fn) -> Result<U>
348 where
349 Fu: Future<Output = Result<T>>,
350 Fn: FnOnce(T) -> U, {
351 match result {
352 Err(err) => Err(err),
353 Ok(future) => match future.await {
354 Err(err) => Err(err),
355 Ok(t) => Ok(f(t)),
356 },
357 }
358 }
359
360 async fn retry_on_connection_loss<T, F>(operation: impl Fn() -> F) -> Result<T>
361 where
362 F: Future<Output = Result<T>>, {
363 loop {
364 let future = operation();
365 return match future.await {
366 Err(Error::ConnectionLoss) => continue,
367 result => result,
368 };
369 }
370 }
371
372 fn parse_sequence(client_path: &str, prefix: &str) -> Result<CreateSequence> {
373 if let Some(sequence_path) = client_path.strip_prefix(prefix) {
374 match sequence_path.parse::<i64>() {
375 Err(_) => Err(Error::UnexpectedError(format!("sequential node get no i32 path {client_path}"))),
376 Ok(i) => Ok(CreateSequence(i)),
377 }
378 } else {
379 Err(Error::UnexpectedError(format!("sequential path {client_path} does not contain prefix path {prefix}",)))
380 }
381 }
382
383 pub async fn mkdir(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
395 options.validate_as_directory()?;
396 self.mkdir_internally(path, options).await
397 }
398
399 async fn mkdir_internally(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
400 let mut j = path.len();
401 loop {
402 match self.create(&path[..j], Default::default(), options).await {
403 Ok(_) | Err(Error::NodeExists) => {
404 if j >= path.len() {
405 return Ok(());
406 } else if let Some(i) = path[j + 1..].find('/') {
407 j = j + 1 + i;
408 } else {
409 j = path.len();
410 }
411 },
412 Err(Error::NoNode) => {
413 let i = path[..j].rfind('/').unwrap();
414 if i == 0 {
415 return Err(Error::NoNode);
417 }
418 j = i;
419 },
420 Err(err) => return Err(err),
421 }
422 }
423 }
424
425 pub fn create<'a: 'f, 'b: 'f, 'f>(
439 &'a self,
440 path: &'b str,
441 data: &[u8],
442 options: &CreateOptions<'_>,
443 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
444 Self::wait(self.create_internally(path, data, options))
445 }
446
447 fn create_internally<'a: 'f, 'b: 'f, 'f>(
448 &'a self,
449 path: &'b str,
450 data: &[u8],
451 options: &CreateOptions<'_>,
452 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f> {
453 options.validate()?;
454 let create_mode = options.mode;
455 let sequential = create_mode.is_sequential();
456 let chroot_path = if sequential { self.validate_sequential_path(path)? } else { self.validate_path(path)? };
457 if chroot_path.is_root() {
458 return Err(Error::BadArguments(&"can not create root node"));
459 }
460 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
461 let op_code = if ttl != 0 {
462 OpCode::CreateTtl
463 } else if create_mode.is_container() {
464 OpCode::CreateContainer
465 } else if self.version >= Version(3, 5, 0) {
466 OpCode::Create2
467 } else {
468 OpCode::Create
469 };
470 let flags = create_mode.as_flags(ttl != 0);
471 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
472 let receiver = self.send_request(op_code, &request);
473 Ok(async move {
474 let (body, _) = receiver.await?;
475 let mut buf = body.as_slice();
476 let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
477 let client_path = util::strip_root_path(server_path, self.chroot.root())?;
478 let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) };
479 let stat =
480 if op_code == OpCode::Create { Stat::new_invalid() } else { record::unmarshal::<Stat>(&mut buf)? };
481 Ok((stat, sequence))
482 })
483 }
484
485 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send {
492 Self::wait(self.delete_internally(path, expected_version))
493 }
494
495 fn delete_internally(&self, path: &str, expected_version: Option<i32>) -> Result<impl Future<Output = Result<()>>> {
496 let chroot_path = self.validate_path(path)?;
497 if chroot_path.is_root() {
498 return Err(Error::BadArguments(&"can not delete root node"));
499 }
500 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
501 let receiver = self.send_request(OpCode::Delete, &request);
502 Ok(async move {
503 receiver.await?;
504 Ok(())
505 })
506 }
507
508 fn delete_background(self, path: String) {
510 asyncs::spawn(async move {
511 self.delete_foreground(&path).await;
512 });
513 }
514
515 async fn delete_foreground(&self, path: &str) {
516 Client::retry_on_connection_loss(|| self.delete(path, None)).await.ignore();
517 }
518
519 fn delete_ephemeral_background(self, prefix: String, unique: bool) {
520 asyncs::spawn(async move {
521 let (parent, tree, name) = util::split_path(&prefix);
522 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
523 if unique {
524 if let Some(i) = children.iter().position(|s| s.starts_with(name)) {
525 self.delete_foreground(&children[i]).await;
526 };
527 return Ok::<(), Error>(());
528 }
529 children.retain(|s| s.starts_with(name));
530 for child in children.iter_mut() {
531 child.insert_str(0, tree);
532 }
533 let results = Self::retry_on_connection_loss(|| {
534 let mut reader = self.new_multi_reader();
535 for child in children.iter() {
536 reader.add_get_data(child).unwrap();
537 }
538 reader.commit()
539 })
540 .await?;
541 for (i, result) in results.into_iter().enumerate() {
542 let MultiReadResult::Data { stat, .. } = result else {
543 continue;
545 };
546 if stat.ephemeral_owner == self.session_id().0 {
547 self.delete_foreground(&children[i]).await;
548 break;
549 }
550 }
551 Ok(())
552 });
553 }
554
555 fn get_data_internally(
556 &self,
557 chroot: Chroot,
558 path: &str,
559 watch: bool,
560 ) -> Result<impl Future<Output = Result<(Vec<u8>, Stat, WatchReceiver)>> + Send> {
561 let chroot_path = ChrootPath::new(chroot, path, false)?;
562 let request = GetRequest { path: chroot_path, watch };
563 let receiver = self.send_request(OpCode::GetData, &request);
564 Ok(async move {
565 let (mut body, watcher) = receiver.await?;
566 let data_len = body.len() - Stat::record_len();
567 let mut stat_buf = &body[data_len..];
568 let stat = record::unmarshal(&mut stat_buf)?;
569 body.truncate(data_len);
570 drop(body.drain(..4));
571 Ok((body, stat, watcher))
572 })
573 }
574
575 pub fn get_data(&self, path: &str) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
580 let result = self.get_data_internally(self.chroot.as_ref(), path, false);
581 Self::map_wait(result, |(data, stat, _)| (data, stat))
582 }
583
584 pub fn get_and_watch_data(
594 &self,
595 path: &str,
596 ) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send + '_ {
597 let result = self.get_data_internally(self.chroot.as_ref(), path, true);
598 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&self.chroot)))
599 }
600
601 fn check_stat_internally(
602 &self,
603 path: &str,
604 watch: bool,
605 ) -> Result<impl Future<Output = Result<(Option<Stat>, WatchReceiver)>>> {
606 let chroot_path = self.validate_path(path)?;
607 let request = ExistsRequest { path: chroot_path, watch };
608 let receiver = self.send_request(OpCode::Exists, &request);
609 Ok(async move {
610 let (body, watcher) = receiver.await?;
611 let mut buf = body.as_slice();
612 let stat = record::try_deserialize(&mut buf)?;
613 Ok((stat, watcher))
614 })
615 }
616
617 pub fn check_stat(&self, path: &str) -> impl Future<Output = Result<Option<Stat>>> + Send {
619 Self::map_wait(self.check_stat_internally(path, false), |(stat, _)| stat)
620 }
621
622 pub fn check_and_watch_stat(
629 &self,
630 path: &str,
631 ) -> impl Future<Output = Result<(Option<Stat>, OneshotWatcher)>> + Send + '_ {
632 let result = self.check_stat_internally(path, true);
633 Self::map_wait(result, |(stat, watcher)| (stat, watcher.into_oneshot(&self.chroot)))
634 }
635
636 pub fn set_data(
643 &self,
644 path: &str,
645 data: &[u8],
646 expected_version: Option<i32>,
647 ) -> impl Future<Output = Result<Stat>> + Send {
648 Self::wait(self.set_data_internally(path, data, expected_version))
649 }
650
651 pub fn set_data_internally(
652 &self,
653 path: &str,
654 data: &[u8],
655 expected_version: Option<i32>,
656 ) -> Result<impl Future<Output = Result<Stat>>> {
657 let chroot_path = self.validate_path(path)?;
658 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
659 let receiver = self.send_request(OpCode::SetData, &request);
660 Ok(async move {
661 let (body, _) = receiver.await?;
662 let mut buf = body.as_slice();
663 let stat: Stat = record::unmarshal(&mut buf)?;
664 Ok(stat)
665 })
666 }
667
668 fn list_children_internally(
669 &self,
670 path: &str,
671 watch: bool,
672 ) -> Result<impl Future<Output = Result<(Vec<String>, WatchReceiver)>>> {
673 let chroot_path = self.validate_path(path)?;
674 let request = GetChildrenRequest { path: chroot_path, watch };
675 let receiver = self.send_request(OpCode::GetChildren, &request);
676 Ok(async move {
677 let (body, watcher) = receiver.await?;
678 let mut buf = body.as_slice();
679 let children = record::unmarshal_entity::<Vec<String>>(&"children paths", &mut buf)?;
680 Ok((children, watcher))
681 })
682 }
683
684 pub fn list_children(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
689 Self::map_wait(self.list_children_internally(path, false), |(children, _)| children)
690 }
691
692 pub fn list_and_watch_children(
703 &self,
704 path: &str,
705 ) -> impl Future<Output = Result<(Vec<String>, OneshotWatcher)>> + Send + '_ {
706 let result = self.list_children_internally(path, true);
707 Self::map_wait(result, |(children, watcher)| (children, watcher.into_oneshot(&self.chroot)))
708 }
709
710 fn get_children_internally(
711 &self,
712 path: &str,
713 watch: bool,
714 ) -> Result<impl Future<Output = Result<(Vec<String>, Stat, WatchReceiver)>>> {
715 let chroot_path = self.validate_path(path)?;
716 let request = GetChildrenRequest { path: chroot_path, watch };
717 let receiver = self.send_request(OpCode::GetChildren2, &request);
718 Ok(async move {
719 let (body, watcher) = receiver.await?;
720 let mut buf = body.as_slice();
721 let response = record::unmarshal::<GetChildren2Response>(&mut buf)?;
722 Ok((response.children, response.stat, watcher))
723 })
724 }
725
726 pub fn get_children(&self, path: &str) -> impl Future<Output = Result<(Vec<String>, Stat)>> + Send {
731 let result = self.get_children_internally(path, false);
732 Self::map_wait(result, |(children, stat, _)| (children, stat))
733 }
734
735 pub fn get_and_watch_children(
746 &self,
747 path: &str,
748 ) -> impl Future<Output = Result<(Vec<String>, Stat, OneshotWatcher)>> + Send + '_ {
749 let result = self.get_children_internally(path, true);
750 Self::map_wait(result, |(children, stat, watcher)| (children, stat, watcher.into_oneshot(&self.chroot)))
751 }
752
753 pub fn count_descendants_number(&self, path: &str) -> impl Future<Output = Result<usize>> + Send {
758 Self::wait(self.count_descendants_number_internally(path))
759 }
760
761 fn count_descendants_number_internally(&self, path: &str) -> Result<impl Future<Output = Result<usize>>> {
762 let chroot_path = self.validate_path(path)?;
763 let receiver = self.send_request(OpCode::GetAllChildrenNumber, &chroot_path);
764 Ok(async move {
765 let (body, _) = receiver.await?;
766 let mut buf = body.as_slice();
767 let n = record::unmarshal_entity::<i32>(&"all children number", &mut buf)?;
768 Ok(n as usize)
769 })
770 }
771
772 pub fn list_ephemerals(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
779 Self::wait(self.list_ephemerals_internally(path))
780 }
781
782 fn list_ephemerals_internally(&self, path: &str) -> Result<impl Future<Output = Result<Vec<String>>> + Send + '_> {
783 let path = self.validate_path(path)?;
784 let receiver = self.send_request(OpCode::GetEphemerals, &path);
785 Ok(async move {
786 let (body, _) = receiver.await?;
787 let mut buf = body.as_slice();
788 let mut ephemerals = record::unmarshal_entity::<Vec<String>>(&"ephemerals", &mut buf)?;
789 for ephemeral_path in ephemerals.iter_mut() {
790 util::drain_root_path(ephemeral_path, self.chroot.root())?;
791 }
792 Ok(ephemerals)
793 })
794 }
795
796 pub fn get_acl(&self, path: &str) -> impl Future<Output = Result<(Vec<Acl>, Stat)>> + Send + '_ {
801 Self::wait(self.get_acl_internally(path))
802 }
803
804 fn get_acl_internally(&self, path: &str) -> Result<impl Future<Output = Result<(Vec<Acl>, Stat)>>> {
805 let chroot_path = self.validate_path(path)?;
806 let receiver = self.send_request(OpCode::GetACL, &chroot_path);
807 Ok(async move {
808 let (body, _) = receiver.await?;
809 let mut buf = body.as_slice();
810 let response: GetAclResponse = record::unmarshal(&mut buf)?;
811 Ok((response.acl, response.stat))
812 })
813 }
814
815 pub fn set_acl(
821 &self,
822 path: &str,
823 acl: &[Acl],
824 expected_acl_version: Option<i32>,
825 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
826 Self::wait(self.set_acl_internally(path, acl, expected_acl_version))
827 }
828
829 fn set_acl_internally(
830 &self,
831 path: &str,
832 acl: &[Acl],
833 expected_acl_version: Option<i32>,
834 ) -> Result<impl Future<Output = Result<Stat>>> {
835 let chroot_path = self.validate_path(path)?;
836 let request = SetAclRequest { path: chroot_path, acl, version: expected_acl_version.unwrap_or(-1) };
837 let receiver = self.send_request(OpCode::SetACL, &request);
838 Ok(async move {
839 let (body, _) = receiver.await?;
840 let mut buf = body.as_slice();
841 let stat: Stat = record::unmarshal(&mut buf)?;
842 Ok(stat)
843 })
844 }
845
846 pub fn watch(&self, path: &str, mode: AddWatchMode) -> impl Future<Output = Result<PersistentWatcher>> + Send + '_ {
861 Self::wait(self.watch_internally(path, mode))
862 }
863
864 fn watch_internally(
865 &self,
866 path: &str,
867 mode: AddWatchMode,
868 ) -> Result<impl Future<Output = Result<PersistentWatcher>> + Send + '_> {
869 let chroot_path = self.validate_path(path)?;
870 let proto_mode = proto::AddWatchMode::from(mode);
871 let request = PersistentWatchRequest { path: chroot_path, mode: proto_mode.into() };
872 let receiver = self.send_request(OpCode::AddWatch, &request);
873 Ok(async move {
874 let (_, watcher) = receiver.await?;
875 Ok(watcher.into_persistent(&self.chroot))
876 })
877 }
878
879 pub fn sync(&self, path: &str) -> impl Future<Output = Result<()>> + Send + '_ {
890 Self::wait(self.sync_internally(path))
891 }
892
893 fn sync_internally(&self, path: &str) -> Result<impl Future<Output = Result<()>>> {
894 let chroot_path = self.validate_path(path)?;
895 let request = SyncRequest { path: chroot_path };
896 let receiver = self.send_request(OpCode::Sync, &request);
897 Ok(async move {
898 let (body, _) = receiver.await?;
899 let mut buf = body.as_slice();
900 record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
901 Ok(())
902 })
903 }
904
905 pub fn auth(&self, scheme: &str, auth: &[u8]) -> impl Future<Output = Result<()>> + Send + '_ {
919 let request = AuthPacket { scheme, auth };
920 let receiver = self.send_request(OpCode::Auth, &request);
921 async move {
922 receiver.await?;
923 Ok(())
924 }
925 }
926
927 pub fn list_auth_users(&self) -> impl Future<Output = Result<Vec<AuthUser>>> + Send {
937 let receiver = self.send_request(OpCode::WhoAmI, &());
938 async move {
939 let (body, _) = receiver.await?;
940 let mut buf = body.as_slice();
941 let authed_users = record::unmarshal_entity::<Vec<AuthUser>>(&"authed users", &mut buf)?;
942 Ok(authed_users)
943 }
944 }
945
946 pub fn get_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
948 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, false);
949 Self::map_wait(result, |(data, stat, _)| (data, stat))
950 }
951
952 pub fn get_and_watch_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send {
954 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, true);
955 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&OwnedChroot::default())))
956 }
957
958 pub fn update_ensemble<'a, I: Iterator<Item = &'a str> + Clone>(
966 &self,
967 update: EnsembleUpdate<'a, I>,
968 expected_zxid: Option<i64>,
969 ) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
970 let request = ReconfigRequest { update, version: expected_zxid.unwrap_or(-1) };
971 let receiver = self.send_request(OpCode::Reconfig, &request);
972 async move {
973 let (mut body, _) = receiver.await?;
974 let mut buf = body.as_slice();
975 let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?;
976 let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?;
977 let data_len = data.len();
978 body.truncate(data_len + 4);
979 drop(body.drain(..4));
980 Ok((body, stat))
981 }
982 }
983
984 pub fn new_multi_reader(&self) -> MultiReader<'_> {
986 MultiReader::new(self)
987 }
988
989 pub fn new_multi_writer(&self) -> MultiWriter<'_> {
991 MultiWriter::new(self)
992 }
993
994 pub fn new_check_writer(&self, path: &str, version: Option<i32>) -> Result<CheckWriter<'_>> {
997 let mut writer = self.new_multi_writer();
998 writer.add_check_version(path, version.unwrap_or(-1))?;
999 Ok(CheckWriter { writer })
1000 }
1001
1002 async fn create_lock(
1003 &self,
1004 prefix: LockPrefix<'_>,
1005 data: &[u8],
1006 options: LockOptions<'_>,
1007 ) -> Result<(String, usize)> {
1008 let kind = prefix.kind();
1009 let prefix = prefix.into();
1010 self.validate_sequential_path(&prefix)?;
1011 let (parent, _, _) = util::split_path(&prefix);
1012 let guard = LockingGuard { zk: self, prefix: &prefix, unique: kind.is_unique() };
1013 loop {
1014 let mut result = self.create(&prefix, data, &CreateMode::EphemeralSequential.with_acls(options.acls)).await;
1015 if result == Err(Error::NoNode) {
1016 if let Some(options) = &options.parent {
1017 match Self::retry_on_connection_loss(|| self.mkdir_internally(parent, options)).await {
1018 Ok(_) => continue,
1019 Err(Error::NoNode) => result = Err(Error::NoNode),
1020 Err(err) => return Err(err),
1021 }
1022 }
1023 }
1024 let sequence = match result {
1025 Err(Error::ConnectionLoss) => {
1026 if let Some(sequence) = self.find_lock(&prefix, kind).await? {
1027 sequence
1028 } else {
1029 continue;
1030 }
1031 },
1032 Err(err) => {
1033 if err.has_no_data_change() {
1034 std::mem::forget(guard);
1035 return Err(err);
1036 } else {
1037 return Err(err);
1038 }
1039 },
1040 Ok((_stat, sequence)) => sequence,
1041 };
1042 std::mem::forget(guard);
1043 let prefix_len = prefix.len();
1044 let mut path = prefix;
1045 write!(&mut path, "{sequence}").unwrap();
1046 let sequence_len = path.len() - prefix_len;
1047 return Ok((path, sequence_len));
1048 }
1049 }
1050
1051 async fn find_lock(&self, prefix: &str, kind: LockPrefixKind<'_>) -> Result<Option<CreateSequence>> {
1052 let (parent, tree, name) = util::split_path(prefix);
1053 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1054 if kind.is_unique() {
1055 let Some(i) = children.iter().position(|s| s.starts_with(name)) else {
1056 return Ok(None);
1057 };
1058 let sequence = Self::parse_sequence(&children[i], name)?;
1059 return Ok(Some(sequence));
1060 }
1061 children.retain(|s| s.starts_with(name));
1062 if children.is_empty() {
1063 return Ok(None);
1064 }
1065 for child in children.iter_mut() {
1066 child.insert_str(0, tree);
1067 }
1068 let results = Self::retry_on_connection_loss(|| {
1069 let mut reader = self.new_multi_reader();
1070 for child in children.iter() {
1071 reader.add_get_data(child).unwrap();
1072 }
1073 reader.commit()
1074 })
1075 .await?;
1076 for (i, result) in results.into_iter().enumerate() {
1077 let MultiReadResult::Data { stat, .. } = result else {
1078 continue;
1080 };
1081 if stat.ephemeral_owner == self.session_id().0 {
1082 let sequence = Self::parse_sequence(&children[i], name)?;
1083 return Ok(Some(sequence));
1084 }
1085 }
1086 Ok(None)
1087 }
1088
1089 async fn wait_lock(&self, lock: &str, kind: LockPrefixKind<'_>, sequence_len: usize) -> Result<()> {
1090 let (parent, tree, this) = util::split_path(lock);
1091 loop {
1092 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1093 children.retain(|s| {
1094 s.len() >= sequence_len && kind.filter(s) && s[s.len() - sequence_len..].parse::<i32>().is_ok()
1095 });
1096 children.sort_unstable_by(|a, b| a[a.len() - sequence_len..].cmp(&b[b.len() - sequence_len..]));
1097 match children.binary_search_by(|a| a[a.len() - sequence_len..].cmp(&this[this.len() - sequence_len..])) {
1098 Ok(0) => return Ok(()),
1099 Ok(i) => {
1100 let mut child = children.swap_remove(i - 1);
1101 child.insert_str(0, tree);
1102 let watcher = match Self::retry_on_connection_loss(|| self.get_and_watch_data(&child)).await {
1103 Err(Error::NoNode) => continue,
1104 Err(err) => return Err(err),
1105 Ok((_data, _stat, watcher)) => watcher,
1106 };
1107 watcher.changed().await;
1108 },
1109 Err(_) => return Err(Error::RuntimeInconsistent),
1110 }
1111 }
1112 }
1113
1114 pub async fn lock(
1145 &self,
1146 prefix: LockPrefix<'_>,
1147 data: &[u8],
1148 options: impl Into<LockOptions<'_>>,
1149 ) -> Result<LockClient<'_>> {
1150 let options = options.into();
1151 if options.acls.is_empty() {
1152 return Err(Error::InvalidAcl);
1153 }
1154 let prefix_kind = prefix.kind();
1155 let (lock, sequence_len) = self.create_lock(prefix, data, options).await?;
1156 let client = LockClient { client: self, lock: Cow::from(lock) };
1157 match self.wait_lock(&client.lock, prefix_kind, sequence_len).await {
1158 Err(err @ (Error::RuntimeInconsistent | Error::SessionExpired)) => {
1159 std::mem::forget(client);
1160 Err(err)
1161 },
1162 Err(err) => Err(err),
1163 Ok(_) => Ok(client),
1164 }
1165 }
1166}
1167
1168#[derive(Clone, Debug)]
1171pub struct LockOptions<'a> {
1172 acls: Acls<'a>,
1173 parent: Option<CreateOptions<'a>>,
1174}
1175
1176impl<'a> LockOptions<'a> {
1177 pub fn new(acls: Acls<'a>) -> Self {
1178 Self { acls, parent: None }
1179 }
1180
1181 pub fn with_ancestor_options(mut self, options: CreateOptions<'a>) -> Result<Self> {
1187 options.validate_as_directory()?;
1188 self.parent = Some(options);
1189 Ok(self)
1190 }
1191}
1192
1193impl<'a> From<Acls<'a>> for LockOptions<'a> {
1194 fn from(acls: Acls<'a>) -> Self {
1195 LockOptions::new(acls)
1196 }
1197}
1198
/// Rule used to recognize competing lock nodes among a parent's children.
#[derive(Clone, Copy)]
enum LockPrefixKind<'a> {
    /// Curator-style lock; children compete if they contain `lock_name`.
    Curator { lock_name: &'a str },
    /// Custom prefix; children compete if they contain `lock_name`.
    Custom { lock_name: &'a str },
    /// Shared prefix; children compete if they start with `prefix`.
    Shared { prefix: &'a str },
}

impl LockPrefixKind<'_> {
    /// Whether child `name` takes part in the same lock.
    fn filter(&self, name: &str) -> bool {
        match self {
            Self::Shared { prefix } => name.starts_with(prefix),
            Self::Curator { lock_name } | Self::Custom { lock_name } => name.contains(lock_name),
        }
    }

    /// Whether the prefix identifies at most one node of this session (Curator style only).
    fn is_unique(&self) -> bool {
        match self {
            Self::Curator { .. } => true,
            Self::Custom { .. } | Self::Shared { .. } => false,
        }
    }
}
1219
1220#[derive(Debug)]
1221enum LockPrefixInner<'a> {
1222 Curator { dir: &'a str, name: &'a str },
1223 Custom { prefix: String, name: &'a str },
1224 Shared { prefix: &'a str },
1225}
1226
1227#[derive(Debug)]
1236pub struct LockPrefix<'a> {
1237 inner: LockPrefixInner<'a>,
1238}
1239
1240impl<'a> LockPrefix<'a> {
1241 pub fn new_curator(dir: &'a str, name: &'a str) -> Result<Self> {
1248 crate::util::validate_path(Chroot::default(), dir, false)?;
1249 if name.find('/').is_some() {
1250 return Err(Error::BadArguments(&"lock name must not contain /"));
1251 }
1252 Ok(Self { inner: LockPrefixInner::Curator { dir, name } })
1253 }
1254
1255 pub fn new_shared(prefix: &'a str) -> Result<Self> {
1267 crate::util::validate_path(Chroot::default(), prefix, true)?;
1268 Ok(Self { inner: LockPrefixInner::Shared { prefix } })
1269 }
1270
1271 pub fn new_custom(prefix: String, name: &'a str) -> Result<Self> {
1287 crate::util::validate_path(Chroot::default(), &prefix, true)?;
1288 if !name.is_empty() {
1289 let (_dir, _tree, this) = util::split_path(&prefix);
1290 if !this.contains(name) {
1291 return Err(Error::BadArguments(&"lock path prefix must contain lock name"));
1292 }
1293 }
1294 Ok(Self { inner: LockPrefixInner::Custom { prefix, name } })
1295 }
1296
1297 fn kind(&self) -> LockPrefixKind<'a> {
1298 match &self.inner {
1299 LockPrefixInner::Curator { name, .. } => LockPrefixKind::Curator { lock_name: name },
1300 LockPrefixInner::Shared { prefix } => {
1301 let (_parent, _tree, name) = util::split_path(prefix);
1302 LockPrefixKind::Shared { prefix: name }
1303 },
1304 LockPrefixInner::Custom { name, .. } => LockPrefixKind::Custom { lock_name: name },
1305 }
1306 }
1307
1308 fn into(self) -> String {
1309 match self.inner {
1310 LockPrefixInner::Curator { dir, name } => format!("{}/_c_{}-{}", dir, uuid::Uuid::new_v4(), name),
1311 LockPrefixInner::Shared { prefix } => prefix.to_string(),
1312 LockPrefixInner::Custom { prefix, .. } => prefix,
1313 }
1314 }
1315}
1316
/// Drop guard used while a lock is being acquired.
///
/// On drop it calls `Client::delete_ephemeral_background` with `prefix` and `unique`.
// NOTE(review): presumably this cleans up a partially created lock node when acquisition
// is abandoned; confirm where the guard is disarmed on success.
struct LockingGuard<'a> {
    zk: &'a Client,
    // Path prefix of the lock node being created.
    prefix: &'a str,
    // Forwarded to `delete_ephemeral_background` (true for Curator-style prefixes, per
    // `LockPrefixKind::is_unique`, presumably).
    unique: bool,
}
1322
1323impl Drop for LockingGuard<'_> {
1324 fn drop(&mut self) {
1325 self.zk.clone().delete_ephemeral_background(self.prefix.to_string(), self.unique);
1326 }
1327}
1328
/// Client view bound to a held lock.
///
/// Writes issued through it are built with a check writer on the lock path. Dropping it
/// deletes the lock node via `Client::delete_background`.
#[derive(Debug)]
pub struct LockClient<'a> {
    client: &'a Client,
    // Path of the held lock node, either borrowed or owned.
    lock: Cow<'a, str>,
}
1335
1336impl<'a> LockClient<'a> {
1337 async fn resolve_one_write(
1338 future: impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>>,
1339 ) -> Result<MultiWriteResult> {
1340 let mut results = future.await?;
1341 Ok(results.remove(0))
1342 }
1343
1344 pub fn client(&self) -> &'a Client {
1346 self.client
1347 }
1348
1349 pub fn lock_path(&self) -> &str {
1354 &self.lock
1355 }
1356
1357 pub fn create(
1365 &self,
1366 path: &str,
1367 data: &[u8],
1368 options: &CreateOptions<'_>,
1369 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a {
1370 Client::wait(self.create_internally(path, data, options))
1371 }
1372
1373 fn create_internally(
1374 &self,
1375 path: &str,
1376 data: &[u8],
1377 options: &CreateOptions<'_>,
1378 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a> {
1379 let mut writer = self.client.new_check_writer(&self.lock, None)?;
1380 writer.add_create(path, data, options)?;
1381 let write = writer.commit();
1382 let path_len = path.len();
1387 Ok(async move {
1388 let result = Self::resolve_one_write(write).await?;
1389 let (created_path, stat) = result.into_create()?;
1390 let sequence = if created_path.len() <= path_len {
1391 CreateSequence(-1)
1392 } else {
1393 Client::parse_sequence(&created_path, &created_path[..path_len])?
1394 };
1395 Ok((stat, sequence))
1396 })
1397 }
1398
1399 pub fn set_data(
1401 &self,
1402 path: &str,
1403 data: &[u8],
1404 expected_version: Option<i32>,
1405 ) -> impl Future<Output = Result<Stat>> + Send + 'a {
1406 Client::wait(self.set_data_internally(path, data, expected_version))
1407 }
1408
1409 fn set_data_internally(
1410 &self,
1411 path: &str,
1412 data: &[u8],
1413 expected_version: Option<i32>,
1414 ) -> Result<impl Future<Output = Result<Stat>> + Send + 'a> {
1415 let mut writer = self.new_check_writer();
1416 writer.add_set_data(path, data, expected_version)?;
1417 let write = writer.commit();
1418 Ok(async move {
1419 let result = Self::resolve_one_write(write).await?;
1420 let stat = result.into_set_data()?;
1421 Ok(stat)
1422 })
1423 }
1424
1425 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + 'a {
1427 Client::wait(self.delete_internally(path, expected_version))
1428 }
1429
1430 fn delete_internally(
1431 &self,
1432 path: &str,
1433 expected_version: Option<i32>,
1434 ) -> Result<impl Future<Output = Result<()>> + Send + 'a> {
1435 let mut writer = self.new_check_writer();
1436 writer.add_delete(path, expected_version)?;
1437 let write = writer.commit();
1438 Ok(async move {
1439 let result = Self::resolve_one_write(write).await?;
1440 result.into_delete()
1441 })
1442 }
1443
1444 pub fn new_check_writer(&self) -> CheckWriter<'a> {
1446 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1447 }
1448
1449 pub fn into_owned(self) -> OwnedLockClient {
1451 let client = self.client.clone();
1452 let mut drop = ManuallyDrop::new(self);
1453 let lock = std::mem::take(drop.lock.to_mut());
1454 OwnedLockClient { client: ManuallyDrop::new(client), lock }
1455 }
1456}
1457
1458impl Drop for LockClient<'_> {
1460 fn drop(&mut self) {
1461 let path = std::mem::take(self.lock.to_mut());
1462 let client = self.client.clone();
1463 client.delete_background(path);
1464 }
1465}
1466
/// Owned counterpart of [LockClient] that holds its own clone of the [Client].
///
/// Dropping it calls `Client::delete_background` with the lock path.
// NOTE(review): `Clone` is derived, so clones share one lock path and each clone deletes
// that node when dropped. Dropping any clone therefore presumably releases the lock for
// all of them; confirm this is intended.
#[derive(Clone, Debug)]
pub struct OwnedLockClient {
    // `ManuallyDrop` so that `Drop` can move the client out and hand it to
    // `delete_background`.
    client: ManuallyDrop<Client>,
    // Path of the held lock node.
    lock: String,
}
1473
1474impl OwnedLockClient {
1475 fn lock_client(&self) -> std::mem::ManuallyDrop<LockClient<'_>> {
1476 std::mem::ManuallyDrop::new(LockClient { client: &self.client, lock: Cow::from(&self.lock) })
1477 }
1478
1479 pub fn client(&self) -> &Client {
1481 &self.client
1482 }
1483
1484 pub fn lock_path(&self) -> &str {
1486 &self.lock
1487 }
1488
1489 pub fn create<'a: 'f, 'b: 'f, 'f>(
1491 &'a self,
1492 path: &'b str,
1493 data: &[u8],
1494 options: &CreateOptions<'_>,
1495 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
1496 self.lock_client().create(path, data, options)
1497 }
1498
1499 pub fn set_data(
1501 &self,
1502 path: &str,
1503 data: &[u8],
1504 expected_version: Option<i32>,
1505 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
1506 self.lock_client().set_data(path, data, expected_version)
1507 }
1508
1509 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + '_ {
1511 self.lock_client().delete(path, expected_version)
1512 }
1513
1514 pub fn new_check_writer(&self) -> CheckWriter<'_> {
1516 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1517 }
1518}
1519
1520impl Drop for OwnedLockClient {
1522 fn drop(&mut self) {
1523 let client = unsafe { ManuallyDrop::take(&mut self.client) };
1524 let path = std::mem::take(&mut self.lock);
1525 client.delete_background(path);
1526 }
1527}
1528
/// Server version as `(major, minor, patch)`, set via `Connector::with_server_version`.
///
/// The derived `PartialOrd` compares fields in declaration order, so versions order by
/// major, then minor, then patch.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub(crate) struct Version(u32, u32, u32);
1531
/// Builder holding the settings used to connect a [Client]; see [Connector::connect].
#[derive(Clone)]
#[derive_where(Debug)]
pub struct Connector {
    #[cfg(feature = "tls")]
    tls: Option<TlsOptions>,
    #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
    sasl: Option<SaslOptions>,
    // Pre-marshalled `OpCode::Auth` requests. Excluded from `Debug`, presumably to keep
    // credentials out of logs.
    #[derive_where(skip(Debug))]
    authes: Vec<MarshalledRequest>,
    // Session to hand to the session builder, if any.
    session: Option<SessionInfo>,
    readonly: bool,
    detached: bool,
    // When false, endpoints are cycled before the session starts (see `connect_internally`).
    fail_eagerly: bool,
    server_version: Version,
    session_timeout: Duration,
    connection_timeout: Duration,
}
1552
impl Connector {
    /// Creates a connector with default settings.
    ///
    /// Defaults:
    /// * no TLS or SASL options, no auth packets and no session;
    /// * not readonly, not detached, not failing eagerly;
    /// * server version `(u32::MAX, u32::MAX, u32::MAX)`;
    /// * zero session and connection timeouts.
    // NOTE(review): the meaning of a zero timeout and of the `u32::MAX` version is decided
    // by `Session::builder()` / `Client::new`. Presumably they mean "use a default" and
    // "assume the newest server"; confirm.
    fn new() -> Self {
        Self {
            #[cfg(feature = "tls")]
            tls: None,
            #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
            sasl: None,
            authes: Default::default(),
            session: None,
            readonly: false,
            detached: false,
            fail_eagerly: false,
            server_version: Version(u32::MAX, u32::MAX, u32::MAX),
            session_timeout: Duration::ZERO,
            connection_timeout: Duration::ZERO,
        }
    }

    /// Sets the session timeout forwarded to the session builder.
    pub fn with_session_timeout(mut self, timeout: Duration) -> Self {
        self.session_timeout = timeout;
        self
    }

    /// In-place variant of [Connector::with_session_timeout].
    #[deprecated(since = "0.11.0", note = "use Connector::with_session_timeout instead")]
    pub fn session_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.session_timeout = timeout;
        self
    }

    /// Sets the connection timeout forwarded to the session builder.
    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// In-place variant of [Connector::with_connection_timeout].
    #[deprecated(since = "0.11.0", note = "use Connector::with_connection_timeout instead")]
    pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the readonly flag forwarded to the session builder.
    pub fn with_readonly(mut self, readonly: bool) -> Self {
        self.readonly = readonly;
        self
    }

    /// In-place variant of [Connector::with_readonly].
    #[deprecated(since = "0.11.0", note = "use Connector::with_readonly instead")]
    pub fn readonly(&mut self, readonly: bool) -> &mut Self {
        self.readonly = readonly;
        self
    }

    /// Appends an auth packet (`scheme`, `auth`), marshalled as an `OpCode::Auth` request
    /// and forwarded to the session builder.
    pub fn with_auth(mut self, scheme: &str, auth: &[u8]) -> Self {
        let packet = AuthPacket { scheme, auth };
        let request = MarshalledRequest::new(OpCode::Auth, &packet);
        self.authes.push(request);
        self
    }

    /// In-place variant of [Connector::with_auth] taking owned arguments.
    #[deprecated(since = "0.11.0", note = "use Connector::with_auth instead")]
    pub fn auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut Self {
        let packet = AuthPacket { scheme: &scheme, auth: &auth };
        let request = MarshalledRequest::new(OpCode::Auth, &packet);
        self.authes.push(request);
        self
    }

    /// Sets the session forwarded to the session builder.
    ///
    /// This module's tests pass a session obtained from `Client::into_session`, presumably
    /// to resume it.
    pub fn with_session(mut self, session: SessionInfo) -> Self {
        self.session = Some(session);
        self
    }

    /// In-place variant of [Connector::with_session].
    #[deprecated(since = "0.11.0", note = "use Connector::with_session instead")]
    pub fn session(&mut self, session: SessionInfo) -> &mut Self {
        self.session = Some(session);
        self
    }

    /// Sets the server version passed to `Client::new`.
    ///
    /// Defaults to `(u32::MAX, u32::MAX, u32::MAX)`.
    pub fn with_server_version(mut self, major: u32, minor: u32, patch: u32) -> Self {
        self.server_version = Version(major, minor, patch);
        self
    }

    /// In-place variant of [Connector::with_server_version].
    #[deprecated(since = "0.11.0", note = "use Connector::with_server_version instead")]
    pub fn server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
        self.server_version = Version(major, minor, patch);
        self
    }

    /// Sets the detached flag forwarded to the session builder.
    pub fn with_detached(mut self) -> Self {
        self.detached = true;
        self
    }

    /// In-place variant of [Connector::with_detached].
    #[deprecated(since = "0.11.0", note = "use Connector::with_detached instead")]
    pub fn detached(&mut self) -> &mut Self {
        self.detached = true;
        self
    }

    /// Sets TLS options forwarded to the session builder.
    #[cfg(feature = "tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
    pub fn with_tls(mut self, options: TlsOptions) -> Self {
        self.tls = Some(options);
        self
    }

    /// In-place variant of [Connector::with_tls].
    #[cfg(feature = "tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
    #[deprecated(since = "0.11.0", note = "use Connector::with_tls instead")]
    pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
        self.tls = Some(options);
        self
    }

    /// Sets SASL options forwarded to the session builder.
    #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "sasl", feature = "sasl-gssapi", feature = "sasl-digest-md5"))))]
    pub fn with_sasl(mut self, options: impl Into<SaslOptions>) -> Self {
        self.sasl = Some(options.into());
        self
    }

    /// In-place variant of [Connector::with_sasl].
    #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "sasl", feature = "sasl-gssapi", feature = "sasl-digest-md5"))))]
    #[deprecated(since = "0.11.0", note = "use Connector::with_sasl instead")]
    pub fn sasl(&mut self, options: impl Into<SaslOptions>) -> &mut Self {
        self.sasl = Some(options.into());
        self
    }

    /// Fails eagerly: the endpoint iterator is not switched to cycling before the session
    /// starts.
    ///
    /// In this module's tests, a rejected session then fails with [Error::NoHosts].
    pub fn with_fail_eagerly(mut self) -> Self {
        self.fail_eagerly = true;
        self
    }

    /// In-place variant of [Connector::with_fail_eagerly].
    #[deprecated(since = "0.11.0", note = "use Connector::with_fail_eagerly instead")]
    pub fn fail_eagerly(&mut self) -> &mut Self {
        self.fail_eagerly = true;
        self
    }

    /// Connects to `cluster` and returns a [Client] backed by a spawned session task.
    ///
    /// Steps:
    /// 1. Parse the connect string into endpoints and chroot (`secure` is forwarded to
    ///    `endpoint::parse_connect_string`).
    /// 2. Build a session from this connector's settings.
    /// 3. Start the session against the endpoints.
    /// 4. Spawn `Session::serve` in the background and wrap the rest in a [Client].
    #[instrument(name = "connect", skip_all, fields(session))]
    async fn connect_internally(self, secure: bool, cluster: &str) -> Result<Client> {
        let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
        let builder = Session::builder()
            .with_session(self.session)
            .with_authes(self.authes)
            .with_readonly(self.readonly)
            .with_detached(self.detached)
            .with_session_timeout(self.session_timeout)
            .with_connection_timeout(self.connection_timeout);
        #[cfg(feature = "tls")]
        let builder = builder.with_tls(self.tls);
        #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
        let builder = builder.with_sasl(self.sasl);
        // The strong `Arc<sender>` goes to the client; the session only gets a weak
        // reference.
        let (sender, receiver) = mpsc::unbounded();
        let sender = Arc::new(sender);
        let mut session = builder.build(Arc::downgrade(&sender))?;
        let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
        endpoints.reset();
        // Unless failing eagerly, let the endpoint iterator cycle (presumably so connecting
        // keeps retrying endpoints instead of stopping after one pass).
        if !self.fail_eagerly {
            endpoints.cycle();
        }
        let mut buf = Vec::with_capacity(4096);
        let mut depot = Depot::new();
        let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
        let session_info = session.session.clone();
        let session_timeout = session.session_timeout;
        let mut state_watcher = StateWatcher::new(session.subscribe_state());
        // NOTE(review): the returned state is discarded. Presumably this marks the current
        // state as observed so the watcher only reports later changes; confirm.
        state_watcher.state();
        // The session task takes ownership of endpoints, connection, buffer, depot and the
        // request receiver.
        asyncs::spawn(async move {
            session.serve(endpoints, conn, buf, depot, receiver).await;
        });
        let client =
            Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_watcher);
        Ok(client)
    }

    /// Same as [Connector::connect] but with `secure` set when parsing the connect string.
    #[cfg(feature = "tls")]
    pub async fn secure_connect(self, cluster: &str) -> Result<Client> {
        self.connect_internally(true, cluster).await
    }

    /// Connects to `cluster` (a connect string, e.g. `"127.0.0.1:2181"` as in this
    /// module's tests) with `secure` unset.
    pub async fn connect(self, cluster: &str) -> Result<Client> {
        self.connect_internally(false, cluster).await
    }
}
1805
1806trait MultiBuffer {
1807 fn buffer(&mut self) -> &mut Vec<u8>;
1808
1809 fn op_code() -> OpCode;
1810
1811 fn build_request(&mut self) -> MarshalledRequest {
1812 let buffer = self.buffer();
1813 if buffer.is_empty() {
1814 return Default::default();
1815 }
1816 let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
1817 buffer.append_record(&header);
1818 buffer.finish();
1819 MarshalledRequest(std::mem::take(buffer))
1820 }
1821
1822 fn add_operation(&mut self, op: OpCode, request: &impl Record) {
1823 let buffer = self.buffer();
1824 if buffer.is_empty() {
1825 let n = RequestHeader::record_len() + MultiHeader::record_len() + request.serialized_len();
1826 buffer.prepare_and_reserve(n);
1827 buffer.append_record(&RequestHeader::with_code(Self::op_code()));
1828 }
1829 let header = MultiHeader { op, done: false, err: -1 };
1830 self.buffer().append_record2(&header, request);
1831 }
1832}
1833
/// Result of one operation in a [MultiReader] request.
#[non_exhaustive]
#[derive(Debug)]
pub enum MultiReadResult {
    /// Node data and stat, as returned for a get-data operation.
    Data { data: Vec<u8>, stat: Stat },

    /// Child names, as returned for a get-children operation.
    Children { children: Vec<String> },

    /// Error reported for this individual operation.
    Error { err: Error },
}
1847
/// Batches get-data/get-children operations into a single `OpCode::MultiRead` request.
pub struct MultiReader<'a> {
    client: &'a Client,
    // Encoded request being built; empty until the first operation is added.
    buf: Vec<u8>,
}
1853
1854impl MultiBuffer for MultiReader<'_> {
1855 fn buffer(&mut self) -> &mut Vec<u8> {
1856 &mut self.buf
1857 }
1858
1859 fn op_code() -> OpCode {
1860 OpCode::MultiRead
1861 }
1862}
1863
1864impl<'a> MultiReader<'a> {
1865 fn new(client: &'a Client) -> MultiReader<'a> {
1866 MultiReader { client, buf: Default::default() }
1867 }
1868
1869 pub fn add_get_data(&mut self, path: &str) -> Result<()> {
1873 let chroot_path = self.client.validate_path(path)?;
1874 let request = GetRequest { path: chroot_path, watch: false };
1875 self.add_operation(OpCode::GetData, &request);
1876 Ok(())
1877 }
1878
1879 pub fn add_get_children(&mut self, path: &str) -> Result<()> {
1883 let chroot_path = self.client.validate_path(path)?;
1884 let request = GetChildrenRequest { path: chroot_path, watch: false };
1885 self.add_operation(OpCode::GetChildren, &request);
1886 Ok(())
1887 }
1888
1889 pub fn commit(&mut self) -> impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a {
1894 let request = self.build_request();
1895 Client::resolve(self.commit_internally(request))
1896 }
1897
1898 fn commit_internally(
1899 &self,
1900 request: MarshalledRequest,
1901 ) -> Result<Either<impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a, Vec<MultiReadResult>>> {
1902 if request.is_empty() {
1903 return Ok(Right(Vec::default()));
1904 }
1905 let receiver = self.client.send_marshalled_request(request);
1906 Ok(Left(async move {
1907 let (body, _) = receiver.await?;
1908 let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1909 let mut results = Vec::with_capacity(response.len());
1910 for result in response {
1911 match result {
1912 MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1913 MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1914 MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1915 }
1916 }
1917 Ok(results)
1918 }))
1919 }
1920
1921 pub fn abort(&mut self) {
1923 self.buf.clear();
1924 }
1925}
1926
/// Result of one operation in a [MultiWriter] request.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq)]
pub enum MultiWriteResult {
    /// A version check succeeded.
    Check,

    /// A delete succeeded.
    Delete,

    /// A create succeeded.
    Create {
        /// Created path with the client's chroot stripped. For sequential nodes it carries
        /// the sequence suffix.
        path: String,

        /// Stat of the created node.
        stat: Stat,
    },

    /// A set-data succeeded.
    SetData {
        /// Stat returned for the updated node.
        stat: Stat,
    },
}
1957
1958impl MultiWriteResult {
1959 fn kind(&self) -> &'static str {
1960 match self {
1961 MultiWriteResult::Check => "MultiWriteResult::Check",
1962 MultiWriteResult::Create { .. } => "MultiWriteResult::Create",
1963 MultiWriteResult::Delete => "MultiWriteResult::Delete",
1964 MultiWriteResult::SetData { .. } => "MultiWriteResult::SetData",
1965 }
1966 }
1967
1968 fn into_check(self) -> Result<()> {
1969 match self {
1970 MultiWriteResult::Check => Ok(()),
1971 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Check, got {}", self.kind()))),
1972 }
1973 }
1974
1975 fn into_create(self) -> Result<(String, Stat)> {
1976 match self {
1977 MultiWriteResult::Create { path, stat } => Ok((path, stat)),
1978 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Create, got {}", self.kind()))),
1979 }
1980 }
1981
1982 fn into_set_data(self) -> Result<Stat> {
1983 match self {
1984 MultiWriteResult::SetData { stat } => Ok(stat),
1985 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::SetData, got {}", self.kind()))),
1986 }
1987 }
1988
1989 fn into_delete(self) -> Result<()> {
1990 match self {
1991 MultiWriteResult::Delete => Ok(()),
1992 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Delete, got {}", self.kind()))),
1993 }
1994 }
1995}
1996
/// Error from committing a [MultiWriter].
#[derive(Error, Clone, Debug, PartialEq, Eq)]
pub enum MultiWriteError {
    /// The request as a whole failed (e.g. while awaiting or decoding the response), rather
    /// than a specific operation.
    #[error("{source}")]
    RequestFailed {
        #[from]
        source: Error,
    },

    /// The operation at zero-based `index` (in the order operations were added) failed.
    #[error("operation at index {index} failed: {source}")]
    OperationFailed { index: usize, source: Error },
}
2009
2010impl From<MultiWriteError> for Error {
2011 fn from(err: MultiWriteError) -> Self {
2012 match err {
2013 MultiWriteError::RequestFailed { source } => source,
2014 MultiWriteError::OperationFailed { source, .. } => source,
2015 }
2016 }
2017}
2018
/// Error from committing a [CheckWriter].
#[derive(Error, Clone, Debug, PartialEq, Eq)]
pub enum CheckWriteError {
    /// The request as a whole failed, or the response did not contain the leading check.
    #[error("request failed: {source}")]
    RequestFailed {
        #[from]
        source: Error,
    },

    /// The leading path check (operation 0 of the underlying multi write) failed.
    #[error("path check failed: {source}")]
    CheckFailed { source: Error },

    /// A caller-added operation failed. `index` is zero-based and does not count the
    /// leading check.
    #[error("operation at index {index} failed: {source}")]
    OperationFailed { index: usize, source: Error },
}
2034
2035impl From<MultiWriteError> for CheckWriteError {
2036 fn from(err: MultiWriteError) -> Self {
2037 match err {
2038 MultiWriteError::RequestFailed { source } => CheckWriteError::RequestFailed { source },
2039 MultiWriteError::OperationFailed { index: 0, source } => CheckWriteError::CheckFailed { source },
2040 MultiWriteError::OperationFailed { index, source } => {
2041 CheckWriteError::OperationFailed { index: index - 1, source }
2042 },
2043 }
2044 }
2045}
2046
2047impl From<CheckWriteError> for Error {
2048 fn from(err: CheckWriteError) -> Self {
2049 match err {
2050 CheckWriteError::RequestFailed { source } => source,
2051 CheckWriteError::CheckFailed { source: Error::NoNode | Error::BadVersion } => Error::RuntimeInconsistent,
2052 CheckWriteError::CheckFailed { source } => source,
2053 CheckWriteError::OperationFailed { source, .. } => source,
2054 }
2055 }
2056}
2057
/// A [MultiWriter] whose first operation is a path check.
///
/// Committing verifies that the first result is a check and strips it from the returned
/// results.
// NOTE(review): the check itself is presumably added by `Client::new_check_writer`, which
// is outside this chunk; confirm.
pub struct CheckWriter<'a> {
    writer: MultiWriter<'a>,
}
2062
2063impl<'a> CheckWriter<'a> {
2064 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2066 self.writer.add_check_version(path, version)
2067 }
2068
2069 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2071 self.writer.add_create(path, data, options)
2072 }
2073
2074 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2076 self.writer.add_set_data(path, data, expected_version)
2077 }
2078
2079 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2081 self.writer.add_delete(path, expected_version)
2082 }
2083
2084 pub fn commit(
2086 mut self,
2087 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>> + Send + 'a {
2088 let commit = self.writer.commit();
2089 async move {
2090 let mut results = commit.await?;
2091 if results.is_empty() {
2092 Err(CheckWriteError::RequestFailed {
2093 source: Error::UnexpectedError("expect path check, got none".to_string()),
2094 })
2095 } else {
2096 results.remove(0).into_check()?;
2097 Ok(results)
2098 }
2099 }
2100 }
2101}
2102
/// Batches check/create/set-data/delete operations into a single `OpCode::Multi` request.
pub struct MultiWriter<'a> {
    client: &'a Client,
    // Encoded request being built; empty until the first operation is added.
    buf: Vec<u8>,
}
2108
2109impl MultiBuffer for MultiWriter<'_> {
2110 fn buffer(&mut self) -> &mut Vec<u8> {
2111 &mut self.buf
2112 }
2113
2114 fn op_code() -> OpCode {
2115 OpCode::Multi
2116 }
2117}
2118
2119impl<'a> MultiWriter<'a> {
2120 fn new(client: &'a Client) -> MultiWriter<'a> {
2121 MultiWriter { client, buf: Default::default() }
2122 }
2123
2124 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2129 let chroot_path = self.client.validate_path(path)?;
2130 let request = CheckVersionRequest { path: chroot_path, version };
2131 self.add_operation(OpCode::Check, &request);
2132 Ok(())
2133 }
2134
2135 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2146 options.validate()?;
2147 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
2148 let create_mode = options.mode;
2149 let sequential = create_mode.is_sequential();
2150 let chroot_path =
2151 if sequential { self.client.validate_sequential_path(path)? } else { self.client.validate_path(path)? };
2152 let op_code = if ttl != 0 {
2153 OpCode::CreateTtl
2154 } else if create_mode.is_container() {
2155 OpCode::CreateContainer
2156 } else {
2157 OpCode::Create2
2158 };
2159 let flags = create_mode.as_flags(ttl != 0);
2160 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
2161 self.add_operation(op_code, &request);
2162 Ok(())
2163 }
2164
2165 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2169 let chroot_path = self.client.validate_path(path)?;
2170 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
2171 self.add_operation(OpCode::SetData, &request);
2172 Ok(())
2173 }
2174
2175 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2179 let chroot_path = self.client.validate_path(path)?;
2180 if chroot_path.is_root() {
2181 return Err(Error::BadArguments(&"can not delete root node"));
2182 }
2183 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
2184 self.add_operation(OpCode::Delete, &request);
2185 Ok(())
2186 }
2187
2188 pub fn commit(
2196 &mut self,
2197 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a {
2198 let request = self.build_request();
2199 Client::resolve(self.commit_internally(request))
2200 }
2201
2202 #[allow(clippy::type_complexity)]
2203 fn commit_internally(
2204 &self,
2205 request: MarshalledRequest,
2206 ) -> Result<
2207 Either<impl Future<Output = Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a, Vec<MultiWriteResult>>,
2208 MultiWriteError,
2209 > {
2210 if request.is_empty() {
2211 return Ok(Right(Vec::default()));
2212 }
2213 let receiver = self.client.send_marshalled_request(request);
2214 let client = self.client;
2215 Ok(Left(async move {
2216 let (body, _) = receiver.await?;
2217 let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
2218 let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
2219 let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
2220 for (index, result) in response.into_iter().enumerate() {
2221 match result {
2222 MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
2223 MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
2224 MultiWriteResponse::Create { mut path, stat } => {
2225 path = util::strip_root_path(path, client.chroot.root())?;
2226 results.push(MultiWriteResult::Create { path: path.to_string(), stat });
2227 },
2228 MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
2229 MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
2230 MultiWriteResponse::Error(err) => {
2231 return Err(MultiWriteError::OperationFailed { index, source: err })
2232 },
2233 }
2234 }
2235 Ok(results)
2236 }))
2237 }
2238
2239 pub fn abort(&mut self) {
2241 self.buf.clear();
2242 }
2243}
2244
#[cfg(test)]
mod tests {
    use assertor::*;

    use super::*;

    // Exercises each validation error of `CreateOptions::validate`, then a valid TTL.
    #[test]
    fn test_create_options_validate() {
        // An empty ACL list is rejected.
        assert_that!(CreateMode::Persistent.with_acls(Acls::new(Default::default())).validate().unwrap_err())
            .is_equal_to(Error::InvalidAcl);

        let acls = Acls::anyone_all();

        // A TTL is rejected for a non-persistent mode.
        assert_that!(CreateMode::Ephemeral.with_acls(acls).with_ttl(Duration::from_secs(1)).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl can only be specified with persistent node"));

        // A zero TTL is rejected.
        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::ZERO).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl is zero"));

        // TTLs above 1099511627775 ms (2^40 - 1) are rejected.
        assert_that!(CreateMode::Persistent
            .with_acls(acls)
            .with_ttl(Duration::from_millis(0x01FFFFFFFFFF))
            .validate()
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"ttl cannot larger than 1099511627775"));

        // A non-zero, in-range TTL on a persistent node is accepted.
        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::from_secs(5)).validate())
            .is_equal_to(Ok(()));
    }

    // Ancestor (directory) options must be neither ephemeral nor sequential.
    #[test]
    fn test_lock_options_with_ancestor_options() {
        let options = LockOptions::new(Acls::anyone_all());
        assert_that!(options
            .clone()
            .with_ancestor_options(CreateMode::Ephemeral.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be ephemeral"));
        assert_that!(options
            .with_ancestor_options(CreateMode::PersistentSequential.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
    }

    // Requires Docker: runs a ZooKeeper 3.9.0 container and checks how reconnecting with a
    // stored session reacts to the session's `last_zxid`.
    #[test_log::test(asyncs::test)]
    async fn session_last_zxid_seen() {
        use testcontainers::clients::Cli as DockerCli;
        use testcontainers::core::{Healthcheck, WaitFor};
        use testcontainers::images::generic::GenericImage;

        // Wait until `zkServer.sh status` succeeds, polling every 2s for up to 60 retries.
        let healthcheck = Healthcheck::default()
            .with_cmd(["./bin/zkServer.sh", "status"].iter())
            .with_interval(Duration::from_secs(2))
            .with_retries(60);
        let image =
            GenericImage::new("zookeeper", "3.9.0").with_healthcheck(healthcheck).with_wait_for(WaitFor::Healthcheck);
        let docker = DockerCli::default();
        // Kept bound for the whole test (presumably dropping it stops the container).
        let container = docker.run(image);
        let endpoint = format!("127.0.0.1:{}", container.get_host_port(2181));

        let client1 = Client::connector().with_detached().connect(&endpoint).await.unwrap();
        client1.create("/n1", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();

        let mut session = client1.into_session();

        // A session claiming a zxid the server has not reached must not connect. With
        // `with_fail_eagerly`, the attempt ends in `Error::NoHosts`.
        session.last_zxid = i64::MAX;
        assert_that!(Client::connector()
            .with_fail_eagerly()
            .with_session(session.clone())
            .connect(&endpoint)
            .await
            .unwrap_err())
        .is_equal_to(Error::NoHosts);

        // With `last_zxid` reset, the same session connects and can write.
        session.last_zxid = 0;
        let client2 =
            Client::connector().with_fail_eagerly().with_session(session.clone()).connect(&endpoint).await.unwrap();
        client2.create("/n2", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();
    }
}