1mod watcher;
2
3use std::borrow::Cow;
4use std::fmt::Write as _;
5use std::future::Future;
6use std::mem::ManuallyDrop;
7use std::time::Duration;
8
9use const_format::formatcp;
10use either::{Either, Left, Right};
11use futures::channel::mpsc;
12use ignore_result::Ignore;
13use thiserror::Error;
14use tracing::instrument;
15
16pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
17use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver};
18use crate::acl::{Acl, Acls, AuthUser};
19use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
20use crate::endpoint::{self, IterableEndpoints};
21use crate::error::Error;
22use crate::proto::{
23 self,
24 AuthPacket,
25 CheckVersionRequest,
26 CreateRequest,
27 DeleteRequest,
28 ExistsRequest,
29 GetAclResponse,
30 GetChildren2Response,
31 GetChildrenRequest,
32 GetRequest,
33 MultiHeader,
34 MultiReadResponse,
35 MultiWriteResponse,
36 OpCode,
37 PersistentWatchRequest,
38 ReconfigRequest,
39 RequestBuffer,
40 RequestHeader,
41 SetAclRequest,
42 SetDataRequest,
43 SyncRequest,
44};
45pub use crate::proto::{EnsembleUpdate, Stat};
46use crate::record::{self, Record, StaticRecord};
47#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
48use crate::sasl::SaslOptions;
49use crate::session::StateReceiver;
50pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
51#[cfg(feature = "tls")]
52use crate::tls::TlsOptions;
53use crate::util;
54
/// Crate-internal result alias whose error type defaults to this crate's [`Error`].
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
56
57#[derive(Clone, Copy, Debug, PartialEq, Eq)]
60pub enum CreateMode {
61 Persistent,
62 PersistentSequential,
63 Ephemeral,
64 EphemeralSequential,
65 Container,
66}
67
68impl CreateMode {
69 pub const fn with_acls(self, acls: Acls<'_>) -> CreateOptions<'_> {
71 CreateOptions { mode: self, acls, ttl: None }
72 }
73
74 fn is_sequential(self) -> bool {
75 self == CreateMode::PersistentSequential || self == CreateMode::EphemeralSequential
76 }
77
78 fn is_persistent(self) -> bool {
79 self == Self::Persistent || self == Self::PersistentSequential
80 }
81
82 fn is_ephemeral(self) -> bool {
83 self == Self::Ephemeral || self == Self::EphemeralSequential
84 }
85
86 fn is_container(self) -> bool {
87 self == CreateMode::Container
88 }
89
90 fn as_flags(self, ttl: bool) -> i32 {
91 use CreateMode::*;
92 match self {
93 Persistent => {
94 if ttl {
95 5
96 } else {
97 0
98 }
99 },
100 PersistentSequential => {
101 if ttl {
102 6
103 } else {
104 2
105 }
106 },
107 Ephemeral => 1,
108 EphemeralSequential => 3,
109 Container => 4,
110 }
111 }
112}
113
114#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
116pub enum AddWatchMode {
117 Persistent,
119
120 PersistentRecursive,
122}
123
124impl From<AddWatchMode> for proto::AddWatchMode {
125 fn from(mode: AddWatchMode) -> proto::AddWatchMode {
126 match mode {
127 AddWatchMode::Persistent => proto::AddWatchMode::Persistent,
128 AddWatchMode::PersistentRecursive => proto::AddWatchMode::PersistentRecursive,
129 }
130 }
131}
132
133#[derive(Clone, Debug)]
135pub struct CreateOptions<'a> {
136 mode: CreateMode,
137 acls: Acls<'a>,
138 ttl: Option<Duration>,
139}
140
141const TTL_MAX_MILLIS: u128 = 0x00FFFFFFFFFF;
145
146impl<'a> CreateOptions<'a> {
147 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
149 self.ttl = Some(ttl);
150 self
151 }
152
153 fn validate(&'a self) -> Result<()> {
154 if let Some(ref ttl) = self.ttl {
155 if !self.mode.is_persistent() {
156 return Err(Error::BadArguments(&"ttl can only be specified with persistent node"));
157 } else if ttl.is_zero() {
158 return Err(Error::BadArguments(&"ttl is zero"));
159 } else if ttl.as_millis() > TTL_MAX_MILLIS {
160 return Err(Error::BadArguments(&formatcp!("ttl cannot larger than {}", TTL_MAX_MILLIS)));
161 }
162 }
163 if self.acls.is_empty() {
164 return Err(Error::InvalidAcl);
165 }
166 Ok(())
167 }
168
169 fn validate_as_directory(&self) -> Result<()> {
170 self.validate()?;
171 if self.mode.is_ephemeral() {
172 return Err(Error::BadArguments(&"directory node must not be ephemeral"));
173 } else if self.mode.is_sequential() {
174 return Err(Error::BadArguments(&"directory node must not be sequential"));
175 }
176 Ok(())
177 }
178}
179
/// Sequence number the server appended to a sequential node's name.
///
/// Non-sequential creations report `-1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CreateSequence(i64);

impl std::fmt::Display for CreateSequence {
    /// Zero-pads to 10 digits while the value fits in an `i32`, and to 19 digits beyond that.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = if self.0 > i64::from(i32::MAX) { 19 } else { 10 };
        write!(f, "{:0width$}", self.0, width = width)
    }
}

impl CreateSequence {
    /// Returns the raw sequence number.
    pub fn into_i64(self) -> i64 {
        let Self(sequence) = self;
        sequence
    }
}
204
/// Client of a ZooKeeper session.
///
/// Requests are marshalled and sent as [`SessionOperation`]s through `requester`. Client paths
/// are resolved against `chroot`.
#[derive(Clone, Debug)]
pub struct Client {
    chroot: OwnedChroot,
    // Compared with `Version(3, 5, 0)` to choose `Create2` over `Create`.
    version: Version,
    session: SessionInfo,
    session_timeout: Duration,
    // Channel to the session task; if sending fails, requests are answered with
    // `state().to_error()`.
    requester: mpsc::UnboundedSender<SessionOperation>,
    state_watcher: StateWatcher,
}
227
228impl Client {
229 const CONFIG_NODE: &'static str = "/zookeeper/config";
230
231 pub async fn connect(cluster: &str) -> Result<Self> {
233 Self::connector().connect(cluster).await
234 }
235
236 #[deprecated(since = "0.7.0", note = "use Client::connector instead")]
238 pub fn builder() -> ClientBuilder {
239 ClientBuilder::new()
240 }
241
242 pub fn connector() -> Connector {
244 Connector::new()
245 }
246
247 pub(crate) fn new(
248 chroot: OwnedChroot,
249 version: Version,
250 session: SessionInfo,
251 timeout: Duration,
252 requester: mpsc::UnboundedSender<SessionOperation>,
253 state_watcher: StateWatcher,
254 ) -> Client {
255 Client { chroot, version, session, session_timeout: timeout, requester, state_watcher }
256 }
257
258 fn validate_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
259 ChrootPath::new(self.chroot.as_ref(), path, false)
260 }
261
262 fn validate_sequential_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
263 ChrootPath::new(self.chroot.as_ref(), path, true)
264 }
265
266 pub fn path(&self) -> &str {
268 self.chroot.path()
269 }
270
271 pub fn session(&self) -> &SessionInfo {
273 &self.session
274 }
275
276 pub fn session_id(&self) -> SessionId {
278 self.session().id()
279 }
280
281 pub fn into_session(self) -> SessionInfo {
283 self.session
284 }
285
286 pub fn session_timeout(&self) -> Duration {
288 self.session_timeout
289 }
290
291 pub fn state(&self) -> SessionState {
293 self.state_watcher.peek_state()
294 }
295
296 pub fn state_watcher(&self) -> StateWatcher {
298 let mut watcher = self.state_watcher.clone();
299 watcher.state();
300 watcher
301 }
302
303 pub fn chroot<'a>(mut self, path: impl Into<Cow<'a, str>>) -> std::result::Result<Client, Client> {
311 if self.chroot.chroot(path) {
312 Ok(self)
313 } else {
314 Err(self)
315 }
316 }
317
318 fn send_request(&self, code: OpCode, body: &impl Record) -> StateReceiver {
319 let request = MarshalledRequest::new(code, body);
320 self.send_marshalled_request(request)
321 }
322
323 fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
324 let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
325 if let Err(err) = self.requester.unbounded_send(operation) {
326 let state = self.state();
327 err.into_inner().responser.send(Err(state.to_error()));
328 }
329 receiver
330 }
331
332 async fn wait<T, E, F>(result: std::result::Result<F, E>) -> std::result::Result<T, E>
333 where
334 F: Future<Output = std::result::Result<T, E>>, {
335 match result {
336 Err(err) => Err(err),
337 Ok(future) => future.await,
338 }
339 }
340
341 async fn resolve<T, E, F>(result: std::result::Result<Either<F, T>, E>) -> std::result::Result<T, E>
342 where
343 F: Future<Output = std::result::Result<T, E>>, {
344 match result {
345 Err(err) => Err(err),
346 Ok(Right(r)) => Ok(r),
347 Ok(Left(future)) => future.await,
348 }
349 }
350
351 async fn map_wait<T, U, Fu, Fn>(result: Result<Fu>, f: Fn) -> Result<U>
352 where
353 Fu: Future<Output = Result<T>>,
354 Fn: FnOnce(T) -> U, {
355 match result {
356 Err(err) => Err(err),
357 Ok(future) => match future.await {
358 Err(err) => Err(err),
359 Ok(t) => Ok(f(t)),
360 },
361 }
362 }
363
364 async fn retry_on_connection_loss<T, F>(operation: impl Fn() -> F) -> Result<T>
365 where
366 F: Future<Output = Result<T>>, {
367 loop {
368 let future = operation();
369 return match future.await {
370 Err(Error::ConnectionLoss) => continue,
371 result => result,
372 };
373 }
374 }
375
376 fn parse_sequence(client_path: &str, prefix: &str) -> Result<CreateSequence> {
377 if let Some(sequence_path) = client_path.strip_prefix(prefix) {
378 match sequence_path.parse::<i64>() {
379 Err(_) => Err(Error::UnexpectedError(format!("sequential node get no i32 path {}", client_path))),
380 Ok(i) => Ok(CreateSequence(i)),
381 }
382 } else {
383 Err(Error::UnexpectedError(format!(
384 "sequential path {} does not contain prefix path {}",
385 client_path, prefix
386 )))
387 }
388 }
389
    /// Creates `path` together with any missing ancestors, all with empty data and `options`.
    ///
    /// Nodes that already exist along the way are accepted. `options` must be usable for
    /// directories, i.e. neither ephemeral nor sequential.
    ///
    /// # Errors
    /// Option validation errors, [`Error::NoNode`] when even the top-most component below the
    /// chroot cannot be created, or any other error from `create`.
    pub async fn mkdir(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
        options.validate_as_directory()?;
        self.mkdir_internally(path, options).await
    }

    /// Creation loop behind `mkdir`, without option validation.
    ///
    /// `path[..j]` is the node currently being created. On `NoNode` the loop moves up one path
    /// component. On success or `NodeExists` it moves back down, one component at a time, until
    /// the full `path` exists.
    async fn mkdir_internally(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
        let mut j = path.len();
        loop {
            match self.create(&path[..j], Default::default(), options).await {
                Ok(_) | Err(Error::NodeExists) => {
                    if j >= path.len() {
                        return Ok(());
                    } else if let Some(i) = path[j + 1..].find('/') {
                        // Extend to the next path component.
                        j = j + 1 + i;
                    } else {
                        j = path.len();
                    }
                },
                Err(Error::NoNode) => {
                    // Parent missing: retry one level up. `create` accepted the path, so it
                    // presumably contains a '/' (TODO confirm for all validation outcomes).
                    let i = path[..j].rfind('/').unwrap();
                    if i == 0 {
                        // Nothing but the root is above; presumably the chroot path itself does
                        // not exist.
                        return Err(Error::NoNode);
                    }
                    j = i;
                },
                Err(err) => return Err(err),
            }
        }
    }
431
432 pub fn create<'a: 'f, 'b: 'f, 'f>(
446 &'a self,
447 path: &'b str,
448 data: &[u8],
449 options: &CreateOptions<'_>,
450 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
451 Self::wait(self.create_internally(path, data, options))
452 }
453
454 fn create_internally<'a: 'f, 'b: 'f, 'f>(
455 &'a self,
456 path: &'b str,
457 data: &[u8],
458 options: &CreateOptions<'_>,
459 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f> {
460 options.validate()?;
461 let create_mode = options.mode;
462 let sequential = create_mode.is_sequential();
463 let chroot_path = if sequential { self.validate_sequential_path(path)? } else { self.validate_path(path)? };
464 if chroot_path.is_root() {
465 return Err(Error::BadArguments(&"can not create root node"));
466 }
467 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
468 let op_code = if ttl != 0 {
469 OpCode::CreateTtl
470 } else if create_mode.is_container() {
471 OpCode::CreateContainer
472 } else if self.version >= Version(3, 5, 0) {
473 OpCode::Create2
474 } else {
475 OpCode::Create
476 };
477 let flags = create_mode.as_flags(ttl != 0);
478 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
479 let receiver = self.send_request(op_code, &request);
480 Ok(async move {
481 let (body, _) = receiver.await?;
482 let mut buf = body.as_slice();
483 let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
484 let client_path = util::strip_root_path(server_path, self.chroot.root())?;
485 let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) };
486 let stat =
487 if op_code == OpCode::Create { Stat::new_invalid() } else { record::unmarshal::<Stat>(&mut buf)? };
488 Ok((stat, sequence))
489 })
490 }
491
492 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send {
499 Self::wait(self.delete_internally(path, expected_version))
500 }
501
502 fn delete_internally(&self, path: &str, expected_version: Option<i32>) -> Result<impl Future<Output = Result<()>>> {
503 let chroot_path = self.validate_path(path)?;
504 if chroot_path.is_root() {
505 return Err(Error::BadArguments(&"can not delete root node"));
506 }
507 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
508 let receiver = self.send_request(OpCode::Delete, &request);
509 Ok(async move {
510 receiver.await?;
511 Ok(())
512 })
513 }
514
515 fn delete_background(self, path: String) {
517 asyncs::spawn(async move {
518 self.delete_foreground(&path).await;
519 });
520 }
521
522 async fn delete_foreground(&self, path: &str) {
523 Client::retry_on_connection_loss(|| self.delete(path, None)).await.ignore();
524 }
525
526 fn delete_ephemeral_background(self, prefix: String, unique: bool) {
527 asyncs::spawn(async move {
528 let (parent, tree, name) = util::split_path(&prefix);
529 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
530 if unique {
531 if let Some(i) = children.iter().position(|s| s.starts_with(name)) {
532 self.delete_foreground(&children[i]).await;
533 };
534 return Ok::<(), Error>(());
535 }
536 children.retain(|s| s.starts_with(name));
537 for child in children.iter_mut() {
538 child.insert_str(0, tree);
539 }
540 let results = Self::retry_on_connection_loss(|| {
541 let mut reader = self.new_multi_reader();
542 for child in children.iter() {
543 reader.add_get_data(child).unwrap();
544 }
545 reader.commit()
546 })
547 .await?;
548 for (i, result) in results.into_iter().enumerate() {
549 let MultiReadResult::Data { stat, .. } = result else {
550 continue;
552 };
553 if stat.ephemeral_owner == self.session_id().0 {
554 self.delete_foreground(&children[i]).await;
555 break;
556 }
557 }
558 Ok(())
559 });
560 }
561
562 fn get_data_internally(
563 &self,
564 chroot: Chroot,
565 path: &str,
566 watch: bool,
567 ) -> Result<impl Future<Output = Result<(Vec<u8>, Stat, WatchReceiver)>> + Send> {
568 let chroot_path = ChrootPath::new(chroot, path, false)?;
569 let request = GetRequest { path: chroot_path, watch };
570 let receiver = self.send_request(OpCode::GetData, &request);
571 Ok(async move {
572 let (mut body, watcher) = receiver.await?;
573 let data_len = body.len() - Stat::record_len();
574 let mut stat_buf = &body[data_len..];
575 let stat = record::unmarshal(&mut stat_buf)?;
576 body.truncate(data_len);
577 drop(body.drain(..4));
578 Ok((body, stat, watcher))
579 })
580 }
581
582 pub fn get_data(&self, path: &str) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
587 let result = self.get_data_internally(self.chroot.as_ref(), path, false);
588 Self::map_wait(result, |(data, stat, _)| (data, stat))
589 }
590
591 pub fn get_and_watch_data(
601 &self,
602 path: &str,
603 ) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send + '_ {
604 let result = self.get_data_internally(self.chroot.as_ref(), path, true);
605 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&self.chroot)))
606 }
607
608 fn check_stat_internally(
609 &self,
610 path: &str,
611 watch: bool,
612 ) -> Result<impl Future<Output = Result<(Option<Stat>, WatchReceiver)>>> {
613 let chroot_path = self.validate_path(path)?;
614 let request = ExistsRequest { path: chroot_path, watch };
615 let receiver = self.send_request(OpCode::Exists, &request);
616 Ok(async move {
617 let (body, watcher) = receiver.await?;
618 let mut buf = body.as_slice();
619 let stat = record::try_deserialize(&mut buf)?;
620 Ok((stat, watcher))
621 })
622 }
623
624 pub fn check_stat(&self, path: &str) -> impl Future<Output = Result<Option<Stat>>> + Send {
626 Self::map_wait(self.check_stat_internally(path, false), |(stat, _)| stat)
627 }
628
629 pub fn check_and_watch_stat(
636 &self,
637 path: &str,
638 ) -> impl Future<Output = Result<(Option<Stat>, OneshotWatcher)>> + Send + '_ {
639 let result = self.check_stat_internally(path, true);
640 Self::map_wait(result, |(stat, watcher)| (stat, watcher.into_oneshot(&self.chroot)))
641 }
642
643 pub fn set_data(
650 &self,
651 path: &str,
652 data: &[u8],
653 expected_version: Option<i32>,
654 ) -> impl Future<Output = Result<Stat>> + Send {
655 Self::wait(self.set_data_internally(path, data, expected_version))
656 }
657
658 pub fn set_data_internally(
659 &self,
660 path: &str,
661 data: &[u8],
662 expected_version: Option<i32>,
663 ) -> Result<impl Future<Output = Result<Stat>>> {
664 let chroot_path = self.validate_path(path)?;
665 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
666 let receiver = self.send_request(OpCode::SetData, &request);
667 Ok(async move {
668 let (body, _) = receiver.await?;
669 let mut buf = body.as_slice();
670 let stat: Stat = record::unmarshal(&mut buf)?;
671 Ok(stat)
672 })
673 }
674
675 fn list_children_internally(
676 &self,
677 path: &str,
678 watch: bool,
679 ) -> Result<impl Future<Output = Result<(Vec<String>, WatchReceiver)>>> {
680 let chroot_path = self.validate_path(path)?;
681 let request = GetChildrenRequest { path: chroot_path, watch };
682 let receiver = self.send_request(OpCode::GetChildren, &request);
683 Ok(async move {
684 let (body, watcher) = receiver.await?;
685 let mut buf = body.as_slice();
686 let children = record::unmarshal_entity::<Vec<String>>(&"children paths", &mut buf)?;
687 Ok((children, watcher))
688 })
689 }
690
691 pub fn list_children(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
696 Self::map_wait(self.list_children_internally(path, false), |(children, _)| children)
697 }
698
699 pub fn list_and_watch_children(
710 &self,
711 path: &str,
712 ) -> impl Future<Output = Result<(Vec<String>, OneshotWatcher)>> + Send + '_ {
713 let result = self.list_children_internally(path, true);
714 Self::map_wait(result, |(children, watcher)| (children, watcher.into_oneshot(&self.chroot)))
715 }
716
717 fn get_children_internally(
718 &self,
719 path: &str,
720 watch: bool,
721 ) -> Result<impl Future<Output = Result<(Vec<String>, Stat, WatchReceiver)>>> {
722 let chroot_path = self.validate_path(path)?;
723 let request = GetChildrenRequest { path: chroot_path, watch };
724 let receiver = self.send_request(OpCode::GetChildren2, &request);
725 Ok(async move {
726 let (body, watcher) = receiver.await?;
727 let mut buf = body.as_slice();
728 let response = record::unmarshal::<GetChildren2Response>(&mut buf)?;
729 Ok((response.children, response.stat, watcher))
730 })
731 }
732
733 pub fn get_children(&self, path: &str) -> impl Future<Output = Result<(Vec<String>, Stat)>> + Send {
738 let result = self.get_children_internally(path, false);
739 Self::map_wait(result, |(children, stat, _)| (children, stat))
740 }
741
742 pub fn get_and_watch_children(
753 &self,
754 path: &str,
755 ) -> impl Future<Output = Result<(Vec<String>, Stat, OneshotWatcher)>> + Send + '_ {
756 let result = self.get_children_internally(path, true);
757 Self::map_wait(result, |(children, stat, watcher)| (children, stat, watcher.into_oneshot(&self.chroot)))
758 }
759
760 pub fn count_descendants_number(&self, path: &str) -> impl Future<Output = Result<usize>> + Send {
765 Self::wait(self.count_descendants_number_internally(path))
766 }
767
768 fn count_descendants_number_internally(&self, path: &str) -> Result<impl Future<Output = Result<usize>>> {
769 let chroot_path = self.validate_path(path)?;
770 let receiver = self.send_request(OpCode::GetAllChildrenNumber, &chroot_path);
771 Ok(async move {
772 let (body, _) = receiver.await?;
773 let mut buf = body.as_slice();
774 let n = record::unmarshal_entity::<i32>(&"all children number", &mut buf)?;
775 Ok(n as usize)
776 })
777 }
778
779 pub fn list_ephemerals(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
786 Self::wait(self.list_ephemerals_internally(path))
787 }
788
789 fn list_ephemerals_internally(&self, path: &str) -> Result<impl Future<Output = Result<Vec<String>>> + Send + '_> {
790 let path = self.validate_path(path)?;
791 let receiver = self.send_request(OpCode::GetEphemerals, &path);
792 Ok(async move {
793 let (body, _) = receiver.await?;
794 let mut buf = body.as_slice();
795 let mut ephemerals = record::unmarshal_entity::<Vec<String>>(&"ephemerals", &mut buf)?;
796 for ephemeral_path in ephemerals.iter_mut() {
797 util::drain_root_path(ephemeral_path, self.chroot.root())?;
798 }
799 Ok(ephemerals)
800 })
801 }
802
803 pub fn get_acl(&self, path: &str) -> impl Future<Output = Result<(Vec<Acl>, Stat)>> + Send + '_ {
808 Self::wait(self.get_acl_internally(path))
809 }
810
811 fn get_acl_internally(&self, path: &str) -> Result<impl Future<Output = Result<(Vec<Acl>, Stat)>>> {
812 let chroot_path = self.validate_path(path)?;
813 let receiver = self.send_request(OpCode::GetACL, &chroot_path);
814 Ok(async move {
815 let (body, _) = receiver.await?;
816 let mut buf = body.as_slice();
817 let response: GetAclResponse = record::unmarshal(&mut buf)?;
818 Ok((response.acl, response.stat))
819 })
820 }
821
822 pub fn set_acl(
828 &self,
829 path: &str,
830 acl: &[Acl],
831 expected_acl_version: Option<i32>,
832 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
833 Self::wait(self.set_acl_internally(path, acl, expected_acl_version))
834 }
835
836 fn set_acl_internally(
837 &self,
838 path: &str,
839 acl: &[Acl],
840 expected_acl_version: Option<i32>,
841 ) -> Result<impl Future<Output = Result<Stat>>> {
842 let chroot_path = self.validate_path(path)?;
843 let request = SetAclRequest { path: chroot_path, acl, version: expected_acl_version.unwrap_or(-1) };
844 let receiver = self.send_request(OpCode::SetACL, &request);
845 Ok(async move {
846 let (body, _) = receiver.await?;
847 let mut buf = body.as_slice();
848 let stat: Stat = record::unmarshal(&mut buf)?;
849 Ok(stat)
850 })
851 }
852
853 pub fn watch(&self, path: &str, mode: AddWatchMode) -> impl Future<Output = Result<PersistentWatcher>> + Send + '_ {
868 Self::wait(self.watch_internally(path, mode))
869 }
870
871 fn watch_internally(
872 &self,
873 path: &str,
874 mode: AddWatchMode,
875 ) -> Result<impl Future<Output = Result<PersistentWatcher>> + Send + '_> {
876 let chroot_path = self.validate_path(path)?;
877 let proto_mode = proto::AddWatchMode::from(mode);
878 let request = PersistentWatchRequest { path: chroot_path, mode: proto_mode.into() };
879 let receiver = self.send_request(OpCode::AddWatch, &request);
880 Ok(async move {
881 let (_, watcher) = receiver.await?;
882 Ok(watcher.into_persistent(&self.chroot))
883 })
884 }
885
886 pub fn sync(&self, path: &str) -> impl Future<Output = Result<()>> + Send + '_ {
897 Self::wait(self.sync_internally(path))
898 }
899
900 fn sync_internally(&self, path: &str) -> Result<impl Future<Output = Result<()>>> {
901 let chroot_path = self.validate_path(path)?;
902 let request = SyncRequest { path: chroot_path };
903 let receiver = self.send_request(OpCode::Sync, &request);
904 Ok(async move {
905 let (body, _) = receiver.await?;
906 let mut buf = body.as_slice();
907 record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
908 Ok(())
909 })
910 }
911
912 pub fn auth(&self, scheme: String, auth: Vec<u8>) -> impl Future<Output = Result<()>> + Send + '_ {
926 let request = AuthPacket { scheme, auth };
927 let receiver = self.send_request(OpCode::Auth, &request);
928 async move {
929 receiver.await?;
930 Ok(())
931 }
932 }
933
934 pub fn list_auth_users(&self) -> impl Future<Output = Result<Vec<AuthUser>>> + Send {
944 let receiver = self.send_request(OpCode::WhoAmI, &());
945 async move {
946 let (body, _) = receiver.await?;
947 let mut buf = body.as_slice();
948 let authed_users = record::unmarshal_entity::<Vec<AuthUser>>(&"authed users", &mut buf)?;
949 Ok(authed_users)
950 }
951 }
952
953 pub fn get_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
955 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, false);
956 Self::map_wait(result, |(data, stat, _)| (data, stat))
957 }
958
959 pub fn get_and_watch_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send {
961 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, true);
962 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&OwnedChroot::default())))
963 }
964
965 pub fn update_ensemble<'a, I: Iterator<Item = &'a str> + Clone>(
973 &self,
974 update: EnsembleUpdate<'a, I>,
975 expected_zxid: Option<i64>,
976 ) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
977 let request = ReconfigRequest { update, version: expected_zxid.unwrap_or(-1) };
978 let receiver = self.send_request(OpCode::Reconfig, &request);
979 async move {
980 let (mut body, _) = receiver.await?;
981 let mut buf = body.as_slice();
982 let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?;
983 let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?;
984 let data_len = data.len();
985 body.truncate(data_len + 4);
986 drop(body.drain(..4));
987 Ok((body, stat))
988 }
989 }
990
991 pub fn new_multi_reader(&self) -> MultiReader<'_> {
993 MultiReader::new(self)
994 }
995
996 pub fn new_multi_writer(&self) -> MultiWriter<'_> {
998 MultiWriter::new(self)
999 }
1000
1001 pub fn new_check_writer(&self, path: &str, version: Option<i32>) -> Result<CheckWriter<'_>> {
1004 let mut writer = self.new_multi_writer();
1005 writer.add_check_version(path, version.unwrap_or(-1))?;
1006 Ok(CheckWriter { writer })
1007 }
1008
1009 async fn create_lock(
1010 &self,
1011 prefix: LockPrefix<'_>,
1012 data: &[u8],
1013 options: LockOptions<'_>,
1014 ) -> Result<(String, usize)> {
1015 let kind = prefix.kind();
1016 let prefix = prefix.into();
1017 self.validate_sequential_path(&prefix)?;
1018 let (parent, _, _) = util::split_path(&prefix);
1019 let guard = LockingGuard { zk: self, prefix: &prefix, unique: kind.is_unique() };
1020 loop {
1021 let mut result = self.create(&prefix, data, &CreateMode::EphemeralSequential.with_acls(options.acls)).await;
1022 if result == Err(Error::NoNode) {
1023 if let Some(options) = &options.parent {
1024 match Self::retry_on_connection_loss(|| self.mkdir_internally(parent, options)).await {
1025 Ok(_) => continue,
1026 Err(Error::NoNode) => result = Err(Error::NoNode),
1027 Err(err) => return Err(err),
1028 }
1029 }
1030 }
1031 let sequence = match result {
1032 Err(Error::ConnectionLoss) => {
1033 if let Some(sequence) = self.find_lock(&prefix, kind).await? {
1034 sequence
1035 } else {
1036 continue;
1037 }
1038 },
1039 Err(err) => {
1040 if err.has_no_data_change() {
1041 std::mem::forget(guard);
1042 return Err(err);
1043 } else {
1044 return Err(err);
1045 }
1046 },
1047 Ok((_stat, sequence)) => sequence,
1048 };
1049 std::mem::forget(guard);
1050 let prefix_len = prefix.len();
1051 let mut path = prefix;
1052 write!(&mut path, "{}", sequence).unwrap();
1053 let sequence_len = path.len() - prefix_len;
1054 return Ok((path, sequence_len));
1055 }
1056 }
1057
1058 async fn find_lock(&self, prefix: &str, kind: LockPrefixKind<'_>) -> Result<Option<CreateSequence>> {
1059 let (parent, tree, name) = util::split_path(prefix);
1060 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1061 if kind.is_unique() {
1062 let Some(i) = children.iter().position(|s| s.starts_with(name)) else {
1063 return Ok(None);
1064 };
1065 let sequence = Self::parse_sequence(&children[i], name)?;
1066 return Ok(Some(sequence));
1067 }
1068 children.retain(|s| s.starts_with(name));
1069 if children.is_empty() {
1070 return Ok(None);
1071 }
1072 for child in children.iter_mut() {
1073 child.insert_str(0, tree);
1074 }
1075 let results = Self::retry_on_connection_loss(|| {
1076 let mut reader = self.new_multi_reader();
1077 for child in children.iter() {
1078 reader.add_get_data(child).unwrap();
1079 }
1080 reader.commit()
1081 })
1082 .await?;
1083 for (i, result) in results.into_iter().enumerate() {
1084 let MultiReadResult::Data { stat, .. } = result else {
1085 continue;
1087 };
1088 if stat.ephemeral_owner == self.session_id().0 {
1089 let sequence = Self::parse_sequence(&children[i], name)?;
1090 return Ok(Some(sequence));
1091 }
1092 }
1093 Ok(None)
1094 }
1095
1096 async fn wait_lock(&self, lock: &str, kind: LockPrefixKind<'_>, sequence_len: usize) -> Result<()> {
1097 let (parent, tree, this) = util::split_path(lock);
1098 loop {
1099 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1100 children.retain(|s| {
1101 s.len() >= sequence_len && kind.filter(s) && s[s.len() - sequence_len..].parse::<i32>().is_ok()
1102 });
1103 children.sort_unstable_by(|a, b| a[a.len() - sequence_len..].cmp(&b[b.len() - sequence_len..]));
1104 match children.binary_search_by(|a| a[a.len() - sequence_len..].cmp(&this[this.len() - sequence_len..])) {
1105 Ok(0) => return Ok(()),
1106 Ok(i) => {
1107 let mut child = children.swap_remove(i - 1);
1108 child.insert_str(0, tree);
1109 let watcher = match Self::retry_on_connection_loss(|| self.get_and_watch_data(&child)).await {
1110 Err(Error::NoNode) => continue,
1111 Err(err) => return Err(err),
1112 Ok((_data, _stat, watcher)) => watcher,
1113 };
1114 watcher.changed().await;
1115 },
1116 Err(_) => return Err(Error::RuntimeInconsistent),
1117 }
1118 }
1119 }
1120
1121 pub async fn lock(
1152 &self,
1153 prefix: LockPrefix<'_>,
1154 data: &[u8],
1155 options: impl Into<LockOptions<'_>>,
1156 ) -> Result<LockClient<'_>> {
1157 let options = options.into();
1158 if options.acls.is_empty() {
1159 return Err(Error::InvalidAcl);
1160 }
1161 let prefix_kind = prefix.kind();
1162 let (lock, sequence_len) = self.create_lock(prefix, data, options).await?;
1163 let client = LockClient { client: self, lock: Cow::from(lock) };
1164 match self.wait_lock(&client.lock, prefix_kind, sequence_len).await {
1165 Err(err @ (Error::RuntimeInconsistent | Error::SessionExpired)) => {
1166 std::mem::forget(client);
1167 Err(err)
1168 },
1169 Err(err) => Err(err),
1170 Ok(_) => Ok(client),
1171 }
1172 }
1173}
1174
1175#[derive(Clone, Debug)]
1178pub struct LockOptions<'a> {
1179 acls: Acls<'a>,
1180 parent: Option<CreateOptions<'a>>,
1181}
1182
1183impl<'a> LockOptions<'a> {
1184 pub fn new(acls: Acls<'a>) -> Self {
1185 Self { acls, parent: None }
1186 }
1187
1188 pub fn with_ancestor_options(mut self, options: CreateOptions<'a>) -> Result<Self> {
1194 options.validate_as_directory()?;
1195 self.parent = Some(options);
1196 Ok(self)
1197 }
1198}
1199
1200impl<'a> From<Acls<'a>> for LockOptions<'a> {
1201 fn from(acls: Acls<'a>) -> Self {
1202 LockOptions::new(acls)
1203 }
1204}
1205
/// Borrowed summary of a lock prefix, used to recognise competing lock nodes among siblings.
#[derive(Clone, Copy)]
enum LockPrefixKind<'a> {
    Curator { lock_name: &'a str },
    Custom { lock_name: &'a str },
    Shared { prefix: &'a str },
}

impl LockPrefixKind<'_> {
    /// Whether the sibling node `name` competes for the same lock.
    ///
    /// Curator and custom kinds match anywhere in the name; the shared kind matches only at the
    /// start.
    fn filter(&self, name: &str) -> bool {
        match *self {
            Self::Curator { lock_name } | Self::Custom { lock_name } => name.contains(lock_name),
            Self::Shared { prefix } => name.starts_with(prefix),
        }
    }

    /// Only Curator-style prefixes identify a single node per attempt.
    fn is_unique(&self) -> bool {
        matches!(*self, Self::Curator { .. })
    }
}
1226
1227#[derive(Debug)]
1228enum LockPrefixInner<'a> {
1229 Curator { dir: &'a str, name: &'a str },
1230 Custom { prefix: String, name: &'a str },
1231 Shared { prefix: &'a str },
1232}
1233
1234#[derive(Debug)]
1243pub struct LockPrefix<'a> {
1244 inner: LockPrefixInner<'a>,
1245}
1246
1247impl<'a> LockPrefix<'a> {
1248 pub fn new_curator(dir: &'a str, name: &'a str) -> Result<Self> {
1255 crate::util::validate_path(Chroot::default(), dir, false)?;
1256 if name.find('/').is_some() {
1257 return Err(Error::BadArguments(&"lock name must not contain /"));
1258 }
1259 Ok(Self { inner: LockPrefixInner::Curator { dir, name } })
1260 }
1261
1262 pub fn new_shared(prefix: &'a str) -> Result<Self> {
1274 crate::util::validate_path(Chroot::default(), prefix, true)?;
1275 Ok(Self { inner: LockPrefixInner::Shared { prefix } })
1276 }
1277
1278 pub fn new_custom(prefix: String, name: &'a str) -> Result<Self> {
1294 crate::util::validate_path(Chroot::default(), &prefix, true)?;
1295 if !name.is_empty() {
1296 let (_dir, _tree, this) = util::split_path(&prefix);
1297 if !this.contains(name) {
1298 return Err(Error::BadArguments(&"lock path prefix must contain lock name"));
1299 }
1300 }
1301 Ok(Self { inner: LockPrefixInner::Custom { prefix, name } })
1302 }
1303
1304 fn kind(&self) -> LockPrefixKind<'a> {
1305 match &self.inner {
1306 LockPrefixInner::Curator { name, .. } => LockPrefixKind::Curator { lock_name: name },
1307 LockPrefixInner::Shared { prefix } => {
1308 let (_parent, _tree, name) = util::split_path(prefix);
1309 LockPrefixKind::Shared { prefix: name }
1310 },
1311 LockPrefixInner::Custom { name, .. } => LockPrefixKind::Custom { lock_name: name },
1312 }
1313 }
1314
1315 fn into(self) -> String {
1316 match self.inner {
1317 LockPrefixInner::Curator { dir, name } => format!("{}/_c_{}-{}", dir, uuid::Uuid::new_v4(), name),
1318 LockPrefixInner::Shared { prefix } => prefix.to_string(),
1319 LockPrefixInner::Custom { prefix, .. } => prefix,
1320 }
1321 }
1322}
1323
1324struct LockingGuard<'a> {
1325 zk: &'a Client,
1326 prefix: &'a str,
1327 unique: bool,
1328}
1329
1330impl Drop for LockingGuard<'_> {
1331 fn drop(&mut self) {
1332 self.zk.clone().delete_ephemeral_background(self.prefix.to_string(), self.unique);
1333 }
1334}
1335
1336#[derive(Debug)]
1338pub struct LockClient<'a> {
1339 client: &'a Client,
1340 lock: Cow<'a, str>,
1341}
1342
1343impl<'a> LockClient<'a> {
1344 async fn resolve_one_write(
1345 future: impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>>,
1346 ) -> Result<MultiWriteResult> {
1347 let mut results = future.await?;
1348 Ok(results.remove(0))
1349 }
1350
1351 pub fn client(&self) -> &'a Client {
1353 self.client
1354 }
1355
1356 pub fn lock_path(&self) -> &str {
1361 &self.lock
1362 }
1363
1364 pub fn create(
1372 &self,
1373 path: &str,
1374 data: &[u8],
1375 options: &CreateOptions<'_>,
1376 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a {
1377 Client::wait(self.create_internally(path, data, options))
1378 }
1379
1380 fn create_internally(
1381 &self,
1382 path: &str,
1383 data: &[u8],
1384 options: &CreateOptions<'_>,
1385 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a> {
1386 let mut writer = self.client.new_check_writer(&self.lock, None)?;
1387 writer.add_create(path, data, options)?;
1388 let write = writer.commit();
1389 let path_len = path.len();
1394 Ok(async move {
1395 let result = Self::resolve_one_write(write).await?;
1396 let (created_path, stat) = result.into_create()?;
1397 let sequence = if created_path.len() <= path_len {
1398 CreateSequence(-1)
1399 } else {
1400 Client::parse_sequence(&created_path, &created_path[..path_len])?
1401 };
1402 Ok((stat, sequence))
1403 })
1404 }
1405
1406 pub fn set_data(
1408 &self,
1409 path: &str,
1410 data: &[u8],
1411 expected_version: Option<i32>,
1412 ) -> impl Future<Output = Result<Stat>> + Send + 'a {
1413 Client::wait(self.set_data_internally(path, data, expected_version))
1414 }
1415
1416 fn set_data_internally(
1417 &self,
1418 path: &str,
1419 data: &[u8],
1420 expected_version: Option<i32>,
1421 ) -> Result<impl Future<Output = Result<Stat>> + Send + 'a> {
1422 let mut writer = self.new_check_writer();
1423 writer.add_set_data(path, data, expected_version)?;
1424 let write = writer.commit();
1425 Ok(async move {
1426 let result = Self::resolve_one_write(write).await?;
1427 let stat = result.into_set_data()?;
1428 Ok(stat)
1429 })
1430 }
1431
1432 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + 'a {
1434 Client::wait(self.delete_internally(path, expected_version))
1435 }
1436
1437 fn delete_internally(
1438 &self,
1439 path: &str,
1440 expected_version: Option<i32>,
1441 ) -> Result<impl Future<Output = Result<()>> + Send + 'a> {
1442 let mut writer = self.new_check_writer();
1443 writer.add_delete(path, expected_version)?;
1444 let write = writer.commit();
1445 Ok(async move {
1446 let result = Self::resolve_one_write(write).await?;
1447 result.into_delete()
1448 })
1449 }
1450
1451 pub fn new_check_writer(&self) -> CheckWriter<'a> {
1453 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1454 }
1455
1456 pub fn into_owned(self) -> OwnedLockClient {
1458 let client = self.client.clone();
1459 let mut drop = ManuallyDrop::new(self);
1460 let lock = std::mem::take(drop.lock.to_mut());
1461 OwnedLockClient { client: ManuallyDrop::new(client), lock }
1462 }
1463}
1464
1465impl Drop for LockClient<'_> {
1467 fn drop(&mut self) {
1468 let path = std::mem::take(self.lock.to_mut());
1469 let client = self.client.clone();
1470 client.delete_background(path);
1471 }
1472}
1473
1474#[derive(Clone, Debug)]
1476pub struct OwnedLockClient {
1477 client: ManuallyDrop<Client>,
1478 lock: String,
1479}
1480
1481impl OwnedLockClient {
1482 fn lock_client(&self) -> std::mem::ManuallyDrop<LockClient<'_>> {
1483 std::mem::ManuallyDrop::new(LockClient { client: &self.client, lock: Cow::from(&self.lock) })
1484 }
1485
1486 pub fn client(&self) -> &Client {
1488 &self.client
1489 }
1490
1491 pub fn lock_path(&self) -> &str {
1493 &self.lock
1494 }
1495
1496 pub fn create<'a: 'f, 'b: 'f, 'f>(
1498 &'a self,
1499 path: &'b str,
1500 data: &[u8],
1501 options: &CreateOptions<'_>,
1502 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
1503 self.lock_client().create(path, data, options)
1504 }
1505
1506 pub fn set_data(
1508 &self,
1509 path: &str,
1510 data: &[u8],
1511 expected_version: Option<i32>,
1512 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
1513 self.lock_client().set_data(path, data, expected_version)
1514 }
1515
1516 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + '_ {
1518 self.lock_client().delete(path, expected_version)
1519 }
1520
1521 pub fn new_check_writer(&self) -> CheckWriter<'_> {
1523 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1524 }
1525}
1526
1527impl Drop for OwnedLockClient {
1529 fn drop(&mut self) {
1530 let client = unsafe { ManuallyDrop::take(&mut self.client) };
1531 let path = std::mem::take(&mut self.lock);
1532 client.delete_background(path);
1533 }
1534}
1535
/// Server version as `(major, minor, patch)`.
///
/// The derived ordering compares fields lexicographically, major first.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub(crate) struct Version(u32, u32, u32);
1538
1539#[derive(Clone, Debug)]
1543pub struct Connector {
1544 #[cfg(feature = "tls")]
1545 tls: Option<TlsOptions>,
1546 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1547 sasl: Option<SaslOptions>,
1548 authes: Vec<AuthPacket>,
1549 session: Option<SessionInfo>,
1550 readonly: bool,
1551 detached: bool,
1552 fail_eagerly: bool,
1553 server_version: Version,
1554 session_timeout: Duration,
1555 connection_timeout: Duration,
1556}
1557
1558impl Connector {
1559 fn new() -> Self {
1560 Self {
1561 #[cfg(feature = "tls")]
1562 tls: None,
1563 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1564 sasl: None,
1565 authes: Default::default(),
1566 session: None,
1567 readonly: false,
1568 detached: false,
1569 fail_eagerly: false,
1570 server_version: Version(u32::MAX, u32::MAX, u32::MAX),
1571 session_timeout: Duration::ZERO,
1572 connection_timeout: Duration::ZERO,
1573 }
1574 }
1575
1576 pub fn session_timeout(&mut self, timeout: Duration) -> &mut Self {
1580 self.session_timeout = timeout;
1581 self
1582 }
1583
1584 pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1588 self.connection_timeout = timeout;
1589 self
1590 }
1591
1592 pub fn readonly(&mut self, readonly: bool) -> &mut Self {
1594 self.readonly = readonly;
1595 self
1596 }
1597
1598 pub fn auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut Self {
1600 self.authes.push(AuthPacket { scheme, auth });
1601 self
1602 }
1603
1604 pub fn session(&mut self, session: SessionInfo) -> &mut Self {
1606 self.session = Some(session);
1607 self
1608 }
1609
1610 pub fn server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
1620 self.server_version = Version(major, minor, patch);
1621 self
1622 }
1623
1624 pub fn detached(&mut self) -> &mut Self {
1626 self.detached = true;
1627 self
1628 }
1629
1630 #[cfg(feature = "tls")]
1632 pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
1633 self.tls = Some(options);
1634 self
1635 }
1636
1637 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1639 pub fn sasl(&mut self, options: impl Into<SaslOptions>) -> &mut Self {
1640 self.sasl = Some(options.into());
1641 self
1642 }
1643
1644 pub fn fail_eagerly(&mut self) -> &mut Self {
1649 self.fail_eagerly = true;
1650 self
1651 }
1652
1653 #[instrument(name = "connect", skip_all, fields(session))]
1654 async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
1655 let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
1656 let builder = Session::builder()
1657 .with_session(self.session.take())
1658 .with_authes(&self.authes)
1659 .with_readonly(self.readonly)
1660 .with_detached(self.detached)
1661 .with_session_timeout(self.session_timeout)
1662 .with_connection_timeout(self.connection_timeout);
1663 #[cfg(feature = "tls")]
1664 let builder = builder.with_tls(self.tls.take());
1665 #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1666 let builder = builder.with_sasl(self.sasl.take());
1667 let (mut session, state_receiver) = builder.build()?;
1668 let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
1669 endpoints.reset();
1670 if !self.fail_eagerly {
1671 endpoints.cycle();
1672 }
1673 let mut buf = Vec::with_capacity(4096);
1674 let mut depot = Depot::new();
1675 let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
1676 let (sender, receiver) = mpsc::unbounded();
1677 let session_info = session.session.clone();
1678 let session_timeout = session.session_timeout;
1679 let mut state_watcher = StateWatcher::new(state_receiver);
1680 state_watcher.state();
1682 asyncs::spawn(async move {
1683 session.serve(endpoints, conn, buf, depot, receiver).await;
1684 });
1685 let client =
1686 Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_watcher);
1687 Ok(client)
1688 }
1689
1690 #[cfg(feature = "tls")]
1695 pub async fn secure_connect(&mut self, cluster: &str) -> Result<Client> {
1696 self.connect_internally(true, cluster).await
1697 }
1698
1699 pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
1715 self.connect_internally(false, cluster).await
1716 }
1717}
1718
1719#[derive(Clone, Debug)]
1721pub struct ClientBuilder {
1722 connector: Connector,
1723}
1724
1725impl ClientBuilder {
1726 fn new() -> Self {
1727 Self { connector: Connector::new() }
1728 }
1729
1730 pub fn with_session_timeout(&mut self, timeout: Duration) -> &mut Self {
1734 self.connector.session_timeout(timeout);
1735 self
1736 }
1737
1738 pub fn with_connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1742 self.connector.connection_timeout(timeout);
1743 self
1744 }
1745
1746 pub fn with_readonly(&mut self, readonly: bool) -> &mut ClientBuilder {
1748 self.connector.readonly = readonly;
1749 self
1750 }
1751
1752 pub fn with_auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut ClientBuilder {
1754 self.connector.auth(scheme, auth);
1755 self
1756 }
1757
1758 pub fn assume_server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
1768 self.connector.server_version(major, minor, patch);
1769 self
1770 }
1771
1772 pub fn detach(&mut self) -> &mut Self {
1774 self.connector.detached();
1775 self
1776 }
1777
1778 pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
1784 self.connector.connect(cluster).await
1785 }
1786}
1787
1788trait MultiBuffer {
1789 fn buffer(&mut self) -> &mut Vec<u8>;
1790
1791 fn op_code() -> OpCode;
1792
1793 fn build_request(&mut self) -> MarshalledRequest {
1794 let buffer = self.buffer();
1795 if buffer.is_empty() {
1796 return Default::default();
1797 }
1798 let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
1799 buffer.append_record(&header);
1800 buffer.finish();
1801 MarshalledRequest(std::mem::take(buffer))
1802 }
1803
1804 fn add_operation(&mut self, op: OpCode, request: &impl Record) {
1805 let buffer = self.buffer();
1806 if buffer.is_empty() {
1807 let n = RequestHeader::record_len() + MultiHeader::record_len() + request.serialized_len();
1808 buffer.prepare_and_reserve(n);
1809 buffer.append_record(&RequestHeader::with_code(Self::op_code()));
1810 }
1811 let header = MultiHeader { op, done: false, err: -1 };
1812 self.buffer().append_record2(&header, request);
1813 }
1814}
1815
1816#[non_exhaustive]
1818#[derive(Debug)]
1819pub enum MultiReadResult {
1820 Data { data: Vec<u8>, stat: Stat },
1822
1823 Children { children: Vec<String> },
1825
1826 Error { err: Error },
1828}
1829
1830pub struct MultiReader<'a> {
1832 client: &'a Client,
1833 buf: Vec<u8>,
1834}
1835
1836impl MultiBuffer for MultiReader<'_> {
1837 fn buffer(&mut self) -> &mut Vec<u8> {
1838 &mut self.buf
1839 }
1840
1841 fn op_code() -> OpCode {
1842 OpCode::MultiRead
1843 }
1844}
1845
1846impl<'a> MultiReader<'a> {
1847 fn new(client: &'a Client) -> MultiReader<'a> {
1848 MultiReader { client, buf: Default::default() }
1849 }
1850
1851 pub fn add_get_data(&mut self, path: &str) -> Result<()> {
1855 let chroot_path = self.client.validate_path(path)?;
1856 let request = GetRequest { path: chroot_path, watch: false };
1857 self.add_operation(OpCode::GetData, &request);
1858 Ok(())
1859 }
1860
1861 pub fn add_get_children(&mut self, path: &str) -> Result<()> {
1865 let chroot_path = self.client.validate_path(path)?;
1866 let request = GetChildrenRequest { path: chroot_path, watch: false };
1867 self.add_operation(OpCode::GetChildren, &request);
1868 Ok(())
1869 }
1870
1871 pub fn commit(&mut self) -> impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a {
1876 let request = self.build_request();
1877 Client::resolve(self.commit_internally(request))
1878 }
1879
1880 fn commit_internally(
1881 &self,
1882 request: MarshalledRequest,
1883 ) -> Result<Either<impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a, Vec<MultiReadResult>>> {
1884 if request.is_empty() {
1885 return Ok(Right(Vec::default()));
1886 }
1887 let receiver = self.client.send_marshalled_request(request);
1888 Ok(Left(async move {
1889 let (body, _) = receiver.await?;
1890 let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1891 let mut results = Vec::with_capacity(response.len());
1892 for result in response {
1893 match result {
1894 MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1895 MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1896 MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1897 }
1898 }
1899 Ok(results)
1900 }))
1901 }
1902
1903 pub fn abort(&mut self) {
1905 self.buf.clear();
1906 }
1907}
1908
1909#[non_exhaustive]
1911#[derive(Debug, PartialEq, Eq)]
1912pub enum MultiWriteResult {
1913 Check,
1915
1916 Delete,
1918
1919 Create {
1921 path: String,
1923
1924 stat: Stat,
1931 },
1932
1933 SetData {
1935 stat: Stat,
1937 },
1938}
1939
1940impl MultiWriteResult {
1941 fn kind(&self) -> &'static str {
1942 match self {
1943 MultiWriteResult::Check => "MultiWriteResult::Check",
1944 MultiWriteResult::Create { .. } => "MultiWriteResult::Create",
1945 MultiWriteResult::Delete => "MultiWriteResult::Delete",
1946 MultiWriteResult::SetData { .. } => "MultiWriteResult::SetData",
1947 }
1948 }
1949
1950 fn into_check(self) -> Result<()> {
1951 match self {
1952 MultiWriteResult::Check => Ok(()),
1953 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Check, got {}", self.kind()))),
1954 }
1955 }
1956
1957 fn into_create(self) -> Result<(String, Stat)> {
1958 match self {
1959 MultiWriteResult::Create { path, stat } => Ok((path, stat)),
1960 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Create, got {}", self.kind()))),
1961 }
1962 }
1963
1964 fn into_set_data(self) -> Result<Stat> {
1965 match self {
1966 MultiWriteResult::SetData { stat } => Ok(stat),
1967 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::SetData, got {}", self.kind()))),
1968 }
1969 }
1970
1971 fn into_delete(self) -> Result<()> {
1972 match self {
1973 MultiWriteResult::Delete => Ok(()),
1974 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Delete, got {}", self.kind()))),
1975 }
1976 }
1977}
1978
1979#[derive(Error, Clone, Debug, PartialEq, Eq)]
1981pub enum MultiWriteError {
1982 #[error("{source}")]
1983 RequestFailed {
1984 #[from]
1985 source: Error,
1986 },
1987
1988 #[error("operation at index {index} failed: {source}")]
1989 OperationFailed { index: usize, source: Error },
1990}
1991
1992impl From<MultiWriteError> for Error {
1993 fn from(err: MultiWriteError) -> Self {
1994 match err {
1995 MultiWriteError::RequestFailed { source } => source,
1996 MultiWriteError::OperationFailed { source, .. } => source,
1997 }
1998 }
1999}
2000
2001#[derive(Error, Clone, Debug, PartialEq, Eq)]
2003pub enum CheckWriteError {
2004 #[error("request failed: {source}")]
2005 RequestFailed {
2006 #[from]
2007 source: Error,
2008 },
2009
2010 #[error("path check failed: {source}")]
2011 CheckFailed { source: Error },
2012
2013 #[error("operation at index {index} failed: {source}")]
2014 OperationFailed { index: usize, source: Error },
2015}
2016
2017impl From<MultiWriteError> for CheckWriteError {
2018 fn from(err: MultiWriteError) -> Self {
2019 match err {
2020 MultiWriteError::RequestFailed { source } => CheckWriteError::RequestFailed { source },
2021 MultiWriteError::OperationFailed { index: 0, source } => CheckWriteError::CheckFailed { source },
2022 MultiWriteError::OperationFailed { index, source } => {
2023 CheckWriteError::OperationFailed { index: index - 1, source }
2024 },
2025 }
2026 }
2027}
2028
2029impl From<CheckWriteError> for Error {
2030 fn from(err: CheckWriteError) -> Self {
2031 match err {
2032 CheckWriteError::RequestFailed { source } => source,
2033 CheckWriteError::CheckFailed { source: Error::NoNode | Error::BadVersion } => Error::RuntimeInconsistent,
2034 CheckWriteError::CheckFailed { source } => source,
2035 CheckWriteError::OperationFailed { source, .. } => source,
2036 }
2037 }
2038}
2039
2040pub struct CheckWriter<'a> {
2042 writer: MultiWriter<'a>,
2043}
2044
2045impl<'a> CheckWriter<'a> {
2046 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2048 self.writer.add_check_version(path, version)
2049 }
2050
2051 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2053 self.writer.add_create(path, data, options)
2054 }
2055
2056 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2058 self.writer.add_set_data(path, data, expected_version)
2059 }
2060
2061 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2063 self.writer.add_delete(path, expected_version)
2064 }
2065
2066 pub fn commit(
2068 mut self,
2069 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>> + Send + 'a {
2070 let commit = self.writer.commit();
2071 async move {
2072 let mut results = commit.await?;
2073 if results.is_empty() {
2074 Err(CheckWriteError::RequestFailed {
2075 source: Error::UnexpectedError("expect path check, got none".to_string()),
2076 })
2077 } else {
2078 results.remove(0).into_check()?;
2079 Ok(results)
2080 }
2081 }
2082 }
2083}
2084
2085pub struct MultiWriter<'a> {
2087 client: &'a Client,
2088 buf: Vec<u8>,
2089}
2090
2091impl MultiBuffer for MultiWriter<'_> {
2092 fn buffer(&mut self) -> &mut Vec<u8> {
2093 &mut self.buf
2094 }
2095
2096 fn op_code() -> OpCode {
2097 OpCode::Multi
2098 }
2099}
2100
2101impl<'a> MultiWriter<'a> {
2102 fn new(client: &'a Client) -> MultiWriter<'a> {
2103 MultiWriter { client, buf: Default::default() }
2104 }
2105
2106 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2111 let chroot_path = self.client.validate_path(path)?;
2112 let request = CheckVersionRequest { path: chroot_path, version };
2113 self.add_operation(OpCode::Check, &request);
2114 Ok(())
2115 }
2116
2117 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2128 options.validate()?;
2129 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
2130 let create_mode = options.mode;
2131 let sequential = create_mode.is_sequential();
2132 let chroot_path =
2133 if sequential { self.client.validate_sequential_path(path)? } else { self.client.validate_path(path)? };
2134 let op_code = if ttl != 0 {
2135 OpCode::CreateTtl
2136 } else if create_mode.is_container() {
2137 OpCode::CreateContainer
2138 } else {
2139 OpCode::Create2
2140 };
2141 let flags = create_mode.as_flags(ttl != 0);
2142 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
2143 self.add_operation(op_code, &request);
2144 Ok(())
2145 }
2146
2147 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2151 let chroot_path = self.client.validate_path(path)?;
2152 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
2153 self.add_operation(OpCode::SetData, &request);
2154 Ok(())
2155 }
2156
2157 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2161 let chroot_path = self.client.validate_path(path)?;
2162 if chroot_path.is_root() {
2163 return Err(Error::BadArguments(&"can not delete root node"));
2164 }
2165 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
2166 self.add_operation(OpCode::Delete, &request);
2167 Ok(())
2168 }
2169
2170 pub fn commit(
2178 &mut self,
2179 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a {
2180 let request = self.build_request();
2181 Client::resolve(self.commit_internally(request))
2182 }
2183
2184 #[allow(clippy::type_complexity)]
2185 fn commit_internally(
2186 &self,
2187 request: MarshalledRequest,
2188 ) -> Result<
2189 Either<impl Future<Output = Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a, Vec<MultiWriteResult>>,
2190 MultiWriteError,
2191 > {
2192 if request.is_empty() {
2193 return Ok(Right(Vec::default()));
2194 }
2195 let receiver = self.client.send_marshalled_request(request);
2196 let client = self.client;
2197 Ok(Left(async move {
2198 let (body, _) = receiver.await?;
2199 let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
2200 let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
2201 let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
2202 for (index, result) in response.into_iter().enumerate() {
2203 match result {
2204 MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
2205 MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
2206 MultiWriteResponse::Create { mut path, stat } => {
2207 path = util::strip_root_path(path, client.chroot.root())?;
2208 results.push(MultiWriteResult::Create { path: path.to_string(), stat });
2209 },
2210 MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
2211 MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
2212 MultiWriteResponse::Error(err) => {
2213 return Err(MultiWriteError::OperationFailed { index, source: err })
2214 },
2215 }
2216 }
2217 Ok(results)
2218 }))
2219 }
2220
2221 pub fn abort(&mut self) {
2223 self.buf.clear();
2224 }
2225}
2226
#[cfg(test)]
mod tests {
    use assertor::*;

    use super::*;

    #[test]
    fn test_create_options_validate() {
        // An empty ACL list is rejected.
        let no_acls = Acls::new(Default::default());
        assert_that!(CreateMode::Persistent.with_acls(no_acls).validate().unwrap_err()).is_equal_to(Error::InvalidAcl);

        let acls = Acls::anyone_all();

        // TTL is only allowed for persistent nodes.
        let ephemeral_with_ttl = CreateMode::Ephemeral.with_acls(acls).with_ttl(Duration::from_secs(1));
        assert_that!(ephemeral_with_ttl.validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl can only be specified with persistent node"));

        // TTL must be non-zero ...
        let zero_ttl = CreateMode::Persistent.with_acls(acls).with_ttl(Duration::ZERO);
        assert_that!(zero_ttl.validate().unwrap_err()).is_equal_to(Error::BadArguments(&"ttl is zero"));

        // ... and not exceed the 40-bit millisecond bound.
        let huge_ttl = CreateMode::Persistent.with_acls(acls).with_ttl(Duration::from_millis(0x01FFFFFFFFFF));
        assert_that!(huge_ttl.validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl cannot larger than 1099511627775"));

        let valid = CreateMode::Persistent.with_acls(acls).with_ttl(Duration::from_secs(5));
        assert_that!(valid.validate()).is_equal_to(Ok(()));
    }

    #[test]
    fn test_lock_options_with_ancestor_options() {
        let options = LockOptions::new(Acls::anyone_all());

        let ephemeral = CreateMode::Ephemeral.with_acls(Acls::anyone_all());
        assert_that!(options.clone().with_ancestor_options(ephemeral).unwrap_err())
            .is_equal_to(Error::BadArguments(&"directory node must not be ephemeral"));

        let sequential = CreateMode::PersistentSequential.with_acls(Acls::anyone_all());
        assert_that!(options.with_ancestor_options(sequential).unwrap_err())
            .is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
    }

    #[test_log::test(asyncs::test)]
    async fn session_last_zxid_seen() {
        use testcontainers::clients::Cli as DockerCli;
        use testcontainers::core::{Healthcheck, WaitFor};
        use testcontainers::images::generic::GenericImage;

        let healthcheck = Healthcheck::default()
            .with_cmd(["./bin/zkServer.sh", "status"].iter())
            .with_interval(Duration::from_secs(2))
            .with_retries(60);
        let image =
            GenericImage::new("zookeeper", "3.9.0").with_healthcheck(healthcheck).with_wait_for(WaitFor::Healthcheck);
        let docker = DockerCli::default();
        let container = docker.run(image);
        let endpoint = format!("127.0.0.1:{}", container.get_host_port(2181));
        let persistent = CreateMode::Persistent.with_acls(Acls::anyone_all());

        let first = Client::connector().detached().connect(&endpoint).await.unwrap();
        first.create("/n1", b"", &persistent).await.unwrap();

        let mut session = first.into_session();

        // A session claiming a zxid beyond the server's fails eagerly with NoHosts.
        session.last_zxid = i64::MAX;
        assert_that!(Client::connector().fail_eagerly().session(session.clone()).connect(&endpoint).await.unwrap_err())
            .is_equal_to(Error::NoHosts);

        // With last_zxid reset, the same session reconnects and can write.
        session.last_zxid = 0;
        let second = Client::connector().fail_eagerly().session(session.clone()).connect(&endpoint).await.unwrap();
        second.create("/n2", b"", &persistent).await.unwrap();
    }
}