1mod watcher;
2
3use std::borrow::Cow;
4use std::fmt::Write as _;
5use std::future::Future;
6use std::mem::ManuallyDrop;
7use std::time::Duration;
8
9use const_format::formatcp;
10use either::{Either, Left, Right};
11use futures::channel::mpsc;
12use ignore_result::Ignore;
13use thiserror::Error;
14use tracing::instrument;
15
16pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
17use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver};
18use crate::acl::{Acl, Acls, AuthUser};
19use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
20use crate::endpoint::{self, IterableEndpoints};
21use crate::error::Error;
22use crate::proto::{
23 self,
24 AuthPacket,
25 CheckVersionRequest,
26 CreateRequest,
27 DeleteRequest,
28 ExistsRequest,
29 GetAclResponse,
30 GetChildren2Response,
31 GetChildrenRequest,
32 GetRequest,
33 MultiHeader,
34 MultiReadResponse,
35 MultiWriteResponse,
36 OpCode,
37 PersistentWatchRequest,
38 ReconfigRequest,
39 RequestBuffer,
40 RequestHeader,
41 SetAclRequest,
42 SetDataRequest,
43 SyncRequest,
44};
45pub use crate::proto::{EnsembleUpdate, Stat};
46use crate::record::{self, Record, StaticRecord};
47#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
48use crate::sasl::SaslOptions;
49use crate::session::StateReceiver;
50pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
51#[cfg(feature = "tls")]
52use crate::tls::TlsOptions;
53use crate::util;
54
55pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
56
/// Kind of node to create.
///
/// The sequential variants make the server append a sequence number to the requested path;
/// that number is returned to the caller as a [`CreateSequence`].
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum CreateMode {
    /// Regular persistent node.
    Persistent,
    /// Persistent node with a server-appended sequence suffix.
    PersistentSequential,
    /// Ephemeral node.
    Ephemeral,
    /// Ephemeral node with a server-appended sequence suffix.
    EphemeralSequential,
    /// Container node.
    Container,
}
67
68impl CreateMode {
69 pub const fn with_acls(self, acls: Acls<'_>) -> CreateOptions<'_> {
71 CreateOptions { mode: self, acls, ttl: None }
72 }
73
74 fn is_sequential(self) -> bool {
75 self == CreateMode::PersistentSequential || self == CreateMode::EphemeralSequential
76 }
77
78 fn is_persistent(self) -> bool {
79 self == Self::Persistent || self == Self::PersistentSequential
80 }
81
82 fn is_ephemeral(self) -> bool {
83 self == Self::Ephemeral || self == Self::EphemeralSequential
84 }
85
86 fn is_container(self) -> bool {
87 self == CreateMode::Container
88 }
89
90 fn as_flags(self, ttl: bool) -> i32 {
91 use CreateMode::*;
92 match self {
93 Persistent => {
94 if ttl {
95 5
96 } else {
97 0
98 }
99 },
100 PersistentSequential => {
101 if ttl {
102 6
103 } else {
104 2
105 }
106 },
107 Ephemeral => 1,
108 EphemeralSequential => 3,
109 Container => 4,
110 }
111 }
112}
113
/// Mode of a persistent watch added through `Client::watch`.
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
pub enum AddWatchMode {
    /// Maps to `proto::AddWatchMode::Persistent`.
    Persistent,

    /// Maps to `proto::AddWatchMode::PersistentRecursive`.
    PersistentRecursive,
}
123
124impl From<AddWatchMode> for proto::AddWatchMode {
125 fn from(mode: AddWatchMode) -> proto::AddWatchMode {
126 match mode {
127 AddWatchMode::Persistent => proto::AddWatchMode::Persistent,
128 AddWatchMode::PersistentRecursive => proto::AddWatchMode::PersistentRecursive,
129 }
130 }
131}
132
133#[derive(Clone, Debug)]
135pub struct CreateOptions<'a> {
136 mode: CreateMode,
137 acls: Acls<'a>,
138 ttl: Option<Duration>,
139}
140
/// Largest accepted node TTL in milliseconds: 2^40 - 1 (`0x00FF_FFFF_FFFF`).
const TTL_MAX_MILLIS: u128 = (1 << 40) - 1;
145
146impl<'a> CreateOptions<'a> {
147 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
149 self.ttl = Some(ttl);
150 self
151 }
152
153 fn validate(&'a self) -> Result<()> {
154 if let Some(ref ttl) = self.ttl {
155 if !self.mode.is_persistent() {
156 return Err(Error::BadArguments(&"ttl can only be specified with persistent node"));
157 } else if ttl.is_zero() {
158 return Err(Error::BadArguments(&"ttl is zero"));
159 } else if ttl.as_millis() > TTL_MAX_MILLIS {
160 return Err(Error::BadArguments(&formatcp!("ttl cannot larger than {}", TTL_MAX_MILLIS)));
161 }
162 }
163 if self.acls.is_empty() {
164 return Err(Error::InvalidAcl);
165 }
166 Ok(())
167 }
168
169 fn validate_as_directory(&self) -> Result<()> {
170 self.validate()?;
171 if self.mode.is_ephemeral() {
172 return Err(Error::BadArguments(&"directory node must not be ephemeral"));
173 } else if self.mode.is_sequential() {
174 return Err(Error::BadArguments(&"directory node must not be sequential"));
175 }
176 Ok(())
177 }
178}
179
/// Sequence number assigned by the server to a sequential node.
///
/// Non-sequential creations report `-1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CreateSequence(i64);

impl std::fmt::Display for CreateSequence {
    /// Zero-pads to 10 digits while the value fits in an `i32`, otherwise to 19 digits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = if self.0 > i64::from(i32::MAX) { 19 } else { 10 };
        write!(f, "{:0width$}", self.0, width = width)
    }
}

impl CreateSequence {
    /// Returns the raw sequence number.
    pub fn into_i64(self) -> i64 {
        let CreateSequence(raw) = self;
        raw
    }
}
204
205#[derive(Clone, Debug)]
219pub struct Client {
220 chroot: OwnedChroot,
221 version: Version,
222 session: SessionInfo,
223 session_timeout: Duration,
224 requester: mpsc::UnboundedSender<SessionOperation>,
225 state_watcher: StateWatcher,
226}
227
228impl Client {
229 const CONFIG_NODE: &'static str = "/zookeeper/config";
230
231 pub async fn connect(cluster: &str) -> Result<Self> {
233 Self::connector().connect(cluster).await
234 }
235
236 #[deprecated(since = "0.7.0", note = "use Client::connector instead")]
238 pub fn builder() -> ClientBuilder {
239 ClientBuilder::new()
240 }
241
242 pub fn connector() -> Connector {
244 Connector::new()
245 }
246
247 pub(crate) fn new(
248 chroot: OwnedChroot,
249 version: Version,
250 session: SessionInfo,
251 timeout: Duration,
252 requester: mpsc::UnboundedSender<SessionOperation>,
253 state_watcher: StateWatcher,
254 ) -> Client {
255 Client { chroot, version, session, session_timeout: timeout, requester, state_watcher }
256 }
257
258 fn validate_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
259 ChrootPath::new(self.chroot.as_ref(), path, false)
260 }
261
262 fn validate_sequential_path<'a>(&'a self, path: &'a str) -> Result<ChrootPath<'a>> {
263 ChrootPath::new(self.chroot.as_ref(), path, true)
264 }
265
266 pub fn path(&self) -> &str {
268 self.chroot.path()
269 }
270
271 pub fn session(&self) -> &SessionInfo {
273 &self.session
274 }
275
276 pub fn session_id(&self) -> SessionId {
278 self.session().id()
279 }
280
281 pub fn into_session(self) -> SessionInfo {
283 self.session
284 }
285
286 pub fn session_timeout(&self) -> Duration {
288 self.session_timeout
289 }
290
291 pub fn state(&self) -> SessionState {
293 self.state_watcher.peek_state()
294 }
295
296 pub fn state_watcher(&self) -> StateWatcher {
298 let mut watcher = self.state_watcher.clone();
299 watcher.state();
300 watcher
301 }
302
303 pub fn chroot<'a>(mut self, path: impl Into<Cow<'a, str>>) -> std::result::Result<Client, Client> {
311 if self.chroot.chroot(path) {
312 Ok(self)
313 } else {
314 Err(self)
315 }
316 }
317
318 fn send_request(&self, code: OpCode, body: &impl Record) -> StateReceiver {
319 let request = MarshalledRequest::new(code, body);
320 self.send_marshalled_request(request)
321 }
322
323 fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
324 let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
325 if let Err(err) = self.requester.unbounded_send(operation) {
326 let state = self.state();
327 err.into_inner().responser.send(Err(state.to_error()));
328 }
329 receiver
330 }
331
332 async fn wait<T, E, F>(result: std::result::Result<F, E>) -> std::result::Result<T, E>
333 where
334 F: Future<Output = std::result::Result<T, E>>, {
335 match result {
336 Err(err) => Err(err),
337 Ok(future) => future.await,
338 }
339 }
340
341 async fn resolve<T, E, F>(result: std::result::Result<Either<F, T>, E>) -> std::result::Result<T, E>
342 where
343 F: Future<Output = std::result::Result<T, E>>, {
344 match result {
345 Err(err) => Err(err),
346 Ok(Right(r)) => Ok(r),
347 Ok(Left(future)) => future.await,
348 }
349 }
350
351 async fn map_wait<T, U, Fu, Fn>(result: Result<Fu>, f: Fn) -> Result<U>
352 where
353 Fu: Future<Output = Result<T>>,
354 Fn: FnOnce(T) -> U, {
355 match result {
356 Err(err) => Err(err),
357 Ok(future) => match future.await {
358 Err(err) => Err(err),
359 Ok(t) => Ok(f(t)),
360 },
361 }
362 }
363
364 async fn retry_on_connection_loss<T, F>(operation: impl Fn() -> F) -> Result<T>
365 where
366 F: Future<Output = Result<T>>, {
367 loop {
368 let future = operation();
369 return match future.await {
370 Err(Error::ConnectionLoss) => continue,
371 result => result,
372 };
373 }
374 }
375
376 fn parse_sequence(client_path: &str, prefix: &str) -> Result<CreateSequence> {
377 if let Some(sequence_path) = client_path.strip_prefix(prefix) {
378 match sequence_path.parse::<i64>() {
379 Err(_) => Err(Error::UnexpectedError(format!("sequential node get no i32 path {client_path}"))),
380 Ok(i) => Ok(CreateSequence(i)),
381 }
382 } else {
383 Err(Error::UnexpectedError(format!("sequential path {client_path} does not contain prefix path {prefix}",)))
384 }
385 }
386
387 pub async fn mkdir(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
399 options.validate_as_directory()?;
400 self.mkdir_internally(path, options).await
401 }
402
403 async fn mkdir_internally(&self, path: &str, options: &CreateOptions<'_>) -> Result<()> {
404 let mut j = path.len();
405 loop {
406 match self.create(&path[..j], Default::default(), options).await {
407 Ok(_) | Err(Error::NodeExists) => {
408 if j >= path.len() {
409 return Ok(());
410 } else if let Some(i) = path[j + 1..].find('/') {
411 j = j + 1 + i;
412 } else {
413 j = path.len();
414 }
415 },
416 Err(Error::NoNode) => {
417 let i = path[..j].rfind('/').unwrap();
418 if i == 0 {
419 return Err(Error::NoNode);
421 }
422 j = i;
423 },
424 Err(err) => return Err(err),
425 }
426 }
427 }
428
429 pub fn create<'a: 'f, 'b: 'f, 'f>(
443 &'a self,
444 path: &'b str,
445 data: &[u8],
446 options: &CreateOptions<'_>,
447 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
448 Self::wait(self.create_internally(path, data, options))
449 }
450
451 fn create_internally<'a: 'f, 'b: 'f, 'f>(
452 &'a self,
453 path: &'b str,
454 data: &[u8],
455 options: &CreateOptions<'_>,
456 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f> {
457 options.validate()?;
458 let create_mode = options.mode;
459 let sequential = create_mode.is_sequential();
460 let chroot_path = if sequential { self.validate_sequential_path(path)? } else { self.validate_path(path)? };
461 if chroot_path.is_root() {
462 return Err(Error::BadArguments(&"can not create root node"));
463 }
464 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
465 let op_code = if ttl != 0 {
466 OpCode::CreateTtl
467 } else if create_mode.is_container() {
468 OpCode::CreateContainer
469 } else if self.version >= Version(3, 5, 0) {
470 OpCode::Create2
471 } else {
472 OpCode::Create
473 };
474 let flags = create_mode.as_flags(ttl != 0);
475 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
476 let receiver = self.send_request(op_code, &request);
477 Ok(async move {
478 let (body, _) = receiver.await?;
479 let mut buf = body.as_slice();
480 let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
481 let client_path = util::strip_root_path(server_path, self.chroot.root())?;
482 let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) };
483 let stat =
484 if op_code == OpCode::Create { Stat::new_invalid() } else { record::unmarshal::<Stat>(&mut buf)? };
485 Ok((stat, sequence))
486 })
487 }
488
489 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send {
496 Self::wait(self.delete_internally(path, expected_version))
497 }
498
499 fn delete_internally(&self, path: &str, expected_version: Option<i32>) -> Result<impl Future<Output = Result<()>>> {
500 let chroot_path = self.validate_path(path)?;
501 if chroot_path.is_root() {
502 return Err(Error::BadArguments(&"can not delete root node"));
503 }
504 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
505 let receiver = self.send_request(OpCode::Delete, &request);
506 Ok(async move {
507 receiver.await?;
508 Ok(())
509 })
510 }
511
512 fn delete_background(self, path: String) {
514 asyncs::spawn(async move {
515 self.delete_foreground(&path).await;
516 });
517 }
518
519 async fn delete_foreground(&self, path: &str) {
520 Client::retry_on_connection_loss(|| self.delete(path, None)).await.ignore();
521 }
522
523 fn delete_ephemeral_background(self, prefix: String, unique: bool) {
524 asyncs::spawn(async move {
525 let (parent, tree, name) = util::split_path(&prefix);
526 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
527 if unique {
528 if let Some(i) = children.iter().position(|s| s.starts_with(name)) {
529 self.delete_foreground(&children[i]).await;
530 };
531 return Ok::<(), Error>(());
532 }
533 children.retain(|s| s.starts_with(name));
534 for child in children.iter_mut() {
535 child.insert_str(0, tree);
536 }
537 let results = Self::retry_on_connection_loss(|| {
538 let mut reader = self.new_multi_reader();
539 for child in children.iter() {
540 reader.add_get_data(child).unwrap();
541 }
542 reader.commit()
543 })
544 .await?;
545 for (i, result) in results.into_iter().enumerate() {
546 let MultiReadResult::Data { stat, .. } = result else {
547 continue;
549 };
550 if stat.ephemeral_owner == self.session_id().0 {
551 self.delete_foreground(&children[i]).await;
552 break;
553 }
554 }
555 Ok(())
556 });
557 }
558
559 fn get_data_internally(
560 &self,
561 chroot: Chroot,
562 path: &str,
563 watch: bool,
564 ) -> Result<impl Future<Output = Result<(Vec<u8>, Stat, WatchReceiver)>> + Send> {
565 let chroot_path = ChrootPath::new(chroot, path, false)?;
566 let request = GetRequest { path: chroot_path, watch };
567 let receiver = self.send_request(OpCode::GetData, &request);
568 Ok(async move {
569 let (mut body, watcher) = receiver.await?;
570 let data_len = body.len() - Stat::record_len();
571 let mut stat_buf = &body[data_len..];
572 let stat = record::unmarshal(&mut stat_buf)?;
573 body.truncate(data_len);
574 drop(body.drain(..4));
575 Ok((body, stat, watcher))
576 })
577 }
578
579 pub fn get_data(&self, path: &str) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
584 let result = self.get_data_internally(self.chroot.as_ref(), path, false);
585 Self::map_wait(result, |(data, stat, _)| (data, stat))
586 }
587
588 pub fn get_and_watch_data(
598 &self,
599 path: &str,
600 ) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send + '_ {
601 let result = self.get_data_internally(self.chroot.as_ref(), path, true);
602 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&self.chroot)))
603 }
604
605 fn check_stat_internally(
606 &self,
607 path: &str,
608 watch: bool,
609 ) -> Result<impl Future<Output = Result<(Option<Stat>, WatchReceiver)>>> {
610 let chroot_path = self.validate_path(path)?;
611 let request = ExistsRequest { path: chroot_path, watch };
612 let receiver = self.send_request(OpCode::Exists, &request);
613 Ok(async move {
614 let (body, watcher) = receiver.await?;
615 let mut buf = body.as_slice();
616 let stat = record::try_deserialize(&mut buf)?;
617 Ok((stat, watcher))
618 })
619 }
620
621 pub fn check_stat(&self, path: &str) -> impl Future<Output = Result<Option<Stat>>> + Send {
623 Self::map_wait(self.check_stat_internally(path, false), |(stat, _)| stat)
624 }
625
626 pub fn check_and_watch_stat(
633 &self,
634 path: &str,
635 ) -> impl Future<Output = Result<(Option<Stat>, OneshotWatcher)>> + Send + '_ {
636 let result = self.check_stat_internally(path, true);
637 Self::map_wait(result, |(stat, watcher)| (stat, watcher.into_oneshot(&self.chroot)))
638 }
639
640 pub fn set_data(
647 &self,
648 path: &str,
649 data: &[u8],
650 expected_version: Option<i32>,
651 ) -> impl Future<Output = Result<Stat>> + Send {
652 Self::wait(self.set_data_internally(path, data, expected_version))
653 }
654
655 pub fn set_data_internally(
656 &self,
657 path: &str,
658 data: &[u8],
659 expected_version: Option<i32>,
660 ) -> Result<impl Future<Output = Result<Stat>>> {
661 let chroot_path = self.validate_path(path)?;
662 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
663 let receiver = self.send_request(OpCode::SetData, &request);
664 Ok(async move {
665 let (body, _) = receiver.await?;
666 let mut buf = body.as_slice();
667 let stat: Stat = record::unmarshal(&mut buf)?;
668 Ok(stat)
669 })
670 }
671
672 fn list_children_internally(
673 &self,
674 path: &str,
675 watch: bool,
676 ) -> Result<impl Future<Output = Result<(Vec<String>, WatchReceiver)>>> {
677 let chroot_path = self.validate_path(path)?;
678 let request = GetChildrenRequest { path: chroot_path, watch };
679 let receiver = self.send_request(OpCode::GetChildren, &request);
680 Ok(async move {
681 let (body, watcher) = receiver.await?;
682 let mut buf = body.as_slice();
683 let children = record::unmarshal_entity::<Vec<String>>(&"children paths", &mut buf)?;
684 Ok((children, watcher))
685 })
686 }
687
688 pub fn list_children(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
693 Self::map_wait(self.list_children_internally(path, false), |(children, _)| children)
694 }
695
696 pub fn list_and_watch_children(
707 &self,
708 path: &str,
709 ) -> impl Future<Output = Result<(Vec<String>, OneshotWatcher)>> + Send + '_ {
710 let result = self.list_children_internally(path, true);
711 Self::map_wait(result, |(children, watcher)| (children, watcher.into_oneshot(&self.chroot)))
712 }
713
714 fn get_children_internally(
715 &self,
716 path: &str,
717 watch: bool,
718 ) -> Result<impl Future<Output = Result<(Vec<String>, Stat, WatchReceiver)>>> {
719 let chroot_path = self.validate_path(path)?;
720 let request = GetChildrenRequest { path: chroot_path, watch };
721 let receiver = self.send_request(OpCode::GetChildren2, &request);
722 Ok(async move {
723 let (body, watcher) = receiver.await?;
724 let mut buf = body.as_slice();
725 let response = record::unmarshal::<GetChildren2Response>(&mut buf)?;
726 Ok((response.children, response.stat, watcher))
727 })
728 }
729
730 pub fn get_children(&self, path: &str) -> impl Future<Output = Result<(Vec<String>, Stat)>> + Send {
735 let result = self.get_children_internally(path, false);
736 Self::map_wait(result, |(children, stat, _)| (children, stat))
737 }
738
739 pub fn get_and_watch_children(
750 &self,
751 path: &str,
752 ) -> impl Future<Output = Result<(Vec<String>, Stat, OneshotWatcher)>> + Send + '_ {
753 let result = self.get_children_internally(path, true);
754 Self::map_wait(result, |(children, stat, watcher)| (children, stat, watcher.into_oneshot(&self.chroot)))
755 }
756
757 pub fn count_descendants_number(&self, path: &str) -> impl Future<Output = Result<usize>> + Send {
762 Self::wait(self.count_descendants_number_internally(path))
763 }
764
765 fn count_descendants_number_internally(&self, path: &str) -> Result<impl Future<Output = Result<usize>>> {
766 let chroot_path = self.validate_path(path)?;
767 let receiver = self.send_request(OpCode::GetAllChildrenNumber, &chroot_path);
768 Ok(async move {
769 let (body, _) = receiver.await?;
770 let mut buf = body.as_slice();
771 let n = record::unmarshal_entity::<i32>(&"all children number", &mut buf)?;
772 Ok(n as usize)
773 })
774 }
775
776 pub fn list_ephemerals(&self, path: &str) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
783 Self::wait(self.list_ephemerals_internally(path))
784 }
785
786 fn list_ephemerals_internally(&self, path: &str) -> Result<impl Future<Output = Result<Vec<String>>> + Send + '_> {
787 let path = self.validate_path(path)?;
788 let receiver = self.send_request(OpCode::GetEphemerals, &path);
789 Ok(async move {
790 let (body, _) = receiver.await?;
791 let mut buf = body.as_slice();
792 let mut ephemerals = record::unmarshal_entity::<Vec<String>>(&"ephemerals", &mut buf)?;
793 for ephemeral_path in ephemerals.iter_mut() {
794 util::drain_root_path(ephemeral_path, self.chroot.root())?;
795 }
796 Ok(ephemerals)
797 })
798 }
799
800 pub fn get_acl(&self, path: &str) -> impl Future<Output = Result<(Vec<Acl>, Stat)>> + Send + '_ {
805 Self::wait(self.get_acl_internally(path))
806 }
807
808 fn get_acl_internally(&self, path: &str) -> Result<impl Future<Output = Result<(Vec<Acl>, Stat)>>> {
809 let chroot_path = self.validate_path(path)?;
810 let receiver = self.send_request(OpCode::GetACL, &chroot_path);
811 Ok(async move {
812 let (body, _) = receiver.await?;
813 let mut buf = body.as_slice();
814 let response: GetAclResponse = record::unmarshal(&mut buf)?;
815 Ok((response.acl, response.stat))
816 })
817 }
818
819 pub fn set_acl(
825 &self,
826 path: &str,
827 acl: &[Acl],
828 expected_acl_version: Option<i32>,
829 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
830 Self::wait(self.set_acl_internally(path, acl, expected_acl_version))
831 }
832
833 fn set_acl_internally(
834 &self,
835 path: &str,
836 acl: &[Acl],
837 expected_acl_version: Option<i32>,
838 ) -> Result<impl Future<Output = Result<Stat>>> {
839 let chroot_path = self.validate_path(path)?;
840 let request = SetAclRequest { path: chroot_path, acl, version: expected_acl_version.unwrap_or(-1) };
841 let receiver = self.send_request(OpCode::SetACL, &request);
842 Ok(async move {
843 let (body, _) = receiver.await?;
844 let mut buf = body.as_slice();
845 let stat: Stat = record::unmarshal(&mut buf)?;
846 Ok(stat)
847 })
848 }
849
850 pub fn watch(&self, path: &str, mode: AddWatchMode) -> impl Future<Output = Result<PersistentWatcher>> + Send + '_ {
865 Self::wait(self.watch_internally(path, mode))
866 }
867
868 fn watch_internally(
869 &self,
870 path: &str,
871 mode: AddWatchMode,
872 ) -> Result<impl Future<Output = Result<PersistentWatcher>> + Send + '_> {
873 let chroot_path = self.validate_path(path)?;
874 let proto_mode = proto::AddWatchMode::from(mode);
875 let request = PersistentWatchRequest { path: chroot_path, mode: proto_mode.into() };
876 let receiver = self.send_request(OpCode::AddWatch, &request);
877 Ok(async move {
878 let (_, watcher) = receiver.await?;
879 Ok(watcher.into_persistent(&self.chroot))
880 })
881 }
882
883 pub fn sync(&self, path: &str) -> impl Future<Output = Result<()>> + Send + '_ {
894 Self::wait(self.sync_internally(path))
895 }
896
897 fn sync_internally(&self, path: &str) -> Result<impl Future<Output = Result<()>>> {
898 let chroot_path = self.validate_path(path)?;
899 let request = SyncRequest { path: chroot_path };
900 let receiver = self.send_request(OpCode::Sync, &request);
901 Ok(async move {
902 let (body, _) = receiver.await?;
903 let mut buf = body.as_slice();
904 record::unmarshal_entity::<&str>(&"server path", &mut buf)?;
905 Ok(())
906 })
907 }
908
909 pub fn auth(&self, scheme: String, auth: Vec<u8>) -> impl Future<Output = Result<()>> + Send + '_ {
923 let request = AuthPacket { scheme, auth };
924 let receiver = self.send_request(OpCode::Auth, &request);
925 async move {
926 receiver.await?;
927 Ok(())
928 }
929 }
930
931 pub fn list_auth_users(&self) -> impl Future<Output = Result<Vec<AuthUser>>> + Send {
941 let receiver = self.send_request(OpCode::WhoAmI, &());
942 async move {
943 let (body, _) = receiver.await?;
944 let mut buf = body.as_slice();
945 let authed_users = record::unmarshal_entity::<Vec<AuthUser>>(&"authed users", &mut buf)?;
946 Ok(authed_users)
947 }
948 }
949
950 pub fn get_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
952 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, false);
953 Self::map_wait(result, |(data, stat, _)| (data, stat))
954 }
955
956 pub fn get_and_watch_config(&self) -> impl Future<Output = Result<(Vec<u8>, Stat, OneshotWatcher)>> + Send {
958 let result = self.get_data_internally(Chroot::default(), Self::CONFIG_NODE, true);
959 Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&OwnedChroot::default())))
960 }
961
962 pub fn update_ensemble<'a, I: Iterator<Item = &'a str> + Clone>(
970 &self,
971 update: EnsembleUpdate<'a, I>,
972 expected_zxid: Option<i64>,
973 ) -> impl Future<Output = Result<(Vec<u8>, Stat)>> + Send {
974 let request = ReconfigRequest { update, version: expected_zxid.unwrap_or(-1) };
975 let receiver = self.send_request(OpCode::Reconfig, &request);
976 async move {
977 let (mut body, _) = receiver.await?;
978 let mut buf = body.as_slice();
979 let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?;
980 let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?;
981 let data_len = data.len();
982 body.truncate(data_len + 4);
983 drop(body.drain(..4));
984 Ok((body, stat))
985 }
986 }
987
988 pub fn new_multi_reader(&self) -> MultiReader<'_> {
990 MultiReader::new(self)
991 }
992
993 pub fn new_multi_writer(&self) -> MultiWriter<'_> {
995 MultiWriter::new(self)
996 }
997
998 pub fn new_check_writer(&self, path: &str, version: Option<i32>) -> Result<CheckWriter<'_>> {
1001 let mut writer = self.new_multi_writer();
1002 writer.add_check_version(path, version.unwrap_or(-1))?;
1003 Ok(CheckWriter { writer })
1004 }
1005
1006 async fn create_lock(
1007 &self,
1008 prefix: LockPrefix<'_>,
1009 data: &[u8],
1010 options: LockOptions<'_>,
1011 ) -> Result<(String, usize)> {
1012 let kind = prefix.kind();
1013 let prefix = prefix.into();
1014 self.validate_sequential_path(&prefix)?;
1015 let (parent, _, _) = util::split_path(&prefix);
1016 let guard = LockingGuard { zk: self, prefix: &prefix, unique: kind.is_unique() };
1017 loop {
1018 let mut result = self.create(&prefix, data, &CreateMode::EphemeralSequential.with_acls(options.acls)).await;
1019 if result == Err(Error::NoNode) {
1020 if let Some(options) = &options.parent {
1021 match Self::retry_on_connection_loss(|| self.mkdir_internally(parent, options)).await {
1022 Ok(_) => continue,
1023 Err(Error::NoNode) => result = Err(Error::NoNode),
1024 Err(err) => return Err(err),
1025 }
1026 }
1027 }
1028 let sequence = match result {
1029 Err(Error::ConnectionLoss) => {
1030 if let Some(sequence) = self.find_lock(&prefix, kind).await? {
1031 sequence
1032 } else {
1033 continue;
1034 }
1035 },
1036 Err(err) => {
1037 if err.has_no_data_change() {
1038 std::mem::forget(guard);
1039 return Err(err);
1040 } else {
1041 return Err(err);
1042 }
1043 },
1044 Ok((_stat, sequence)) => sequence,
1045 };
1046 std::mem::forget(guard);
1047 let prefix_len = prefix.len();
1048 let mut path = prefix;
1049 write!(&mut path, "{sequence}").unwrap();
1050 let sequence_len = path.len() - prefix_len;
1051 return Ok((path, sequence_len));
1052 }
1053 }
1054
1055 async fn find_lock(&self, prefix: &str, kind: LockPrefixKind<'_>) -> Result<Option<CreateSequence>> {
1056 let (parent, tree, name) = util::split_path(prefix);
1057 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1058 if kind.is_unique() {
1059 let Some(i) = children.iter().position(|s| s.starts_with(name)) else {
1060 return Ok(None);
1061 };
1062 let sequence = Self::parse_sequence(&children[i], name)?;
1063 return Ok(Some(sequence));
1064 }
1065 children.retain(|s| s.starts_with(name));
1066 if children.is_empty() {
1067 return Ok(None);
1068 }
1069 for child in children.iter_mut() {
1070 child.insert_str(0, tree);
1071 }
1072 let results = Self::retry_on_connection_loss(|| {
1073 let mut reader = self.new_multi_reader();
1074 for child in children.iter() {
1075 reader.add_get_data(child).unwrap();
1076 }
1077 reader.commit()
1078 })
1079 .await?;
1080 for (i, result) in results.into_iter().enumerate() {
1081 let MultiReadResult::Data { stat, .. } = result else {
1082 continue;
1084 };
1085 if stat.ephemeral_owner == self.session_id().0 {
1086 let sequence = Self::parse_sequence(&children[i], name)?;
1087 return Ok(Some(sequence));
1088 }
1089 }
1090 Ok(None)
1091 }
1092
1093 async fn wait_lock(&self, lock: &str, kind: LockPrefixKind<'_>, sequence_len: usize) -> Result<()> {
1094 let (parent, tree, this) = util::split_path(lock);
1095 loop {
1096 let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
1097 children.retain(|s| {
1098 s.len() >= sequence_len && kind.filter(s) && s[s.len() - sequence_len..].parse::<i32>().is_ok()
1099 });
1100 children.sort_unstable_by(|a, b| a[a.len() - sequence_len..].cmp(&b[b.len() - sequence_len..]));
1101 match children.binary_search_by(|a| a[a.len() - sequence_len..].cmp(&this[this.len() - sequence_len..])) {
1102 Ok(0) => return Ok(()),
1103 Ok(i) => {
1104 let mut child = children.swap_remove(i - 1);
1105 child.insert_str(0, tree);
1106 let watcher = match Self::retry_on_connection_loss(|| self.get_and_watch_data(&child)).await {
1107 Err(Error::NoNode) => continue,
1108 Err(err) => return Err(err),
1109 Ok((_data, _stat, watcher)) => watcher,
1110 };
1111 watcher.changed().await;
1112 },
1113 Err(_) => return Err(Error::RuntimeInconsistent),
1114 }
1115 }
1116 }
1117
1118 pub async fn lock(
1149 &self,
1150 prefix: LockPrefix<'_>,
1151 data: &[u8],
1152 options: impl Into<LockOptions<'_>>,
1153 ) -> Result<LockClient<'_>> {
1154 let options = options.into();
1155 if options.acls.is_empty() {
1156 return Err(Error::InvalidAcl);
1157 }
1158 let prefix_kind = prefix.kind();
1159 let (lock, sequence_len) = self.create_lock(prefix, data, options).await?;
1160 let client = LockClient { client: self, lock: Cow::from(lock) };
1161 match self.wait_lock(&client.lock, prefix_kind, sequence_len).await {
1162 Err(err @ (Error::RuntimeInconsistent | Error::SessionExpired)) => {
1163 std::mem::forget(client);
1164 Err(err)
1165 },
1166 Err(err) => Err(err),
1167 Ok(_) => Ok(client),
1168 }
1169 }
1170}
1171
1172#[derive(Clone, Debug)]
1175pub struct LockOptions<'a> {
1176 acls: Acls<'a>,
1177 parent: Option<CreateOptions<'a>>,
1178}
1179
1180impl<'a> LockOptions<'a> {
1181 pub fn new(acls: Acls<'a>) -> Self {
1182 Self { acls, parent: None }
1183 }
1184
1185 pub fn with_ancestor_options(mut self, options: CreateOptions<'a>) -> Result<Self> {
1191 options.validate_as_directory()?;
1192 self.parent = Some(options);
1193 Ok(self)
1194 }
1195}
1196
1197impl<'a> From<Acls<'a>> for LockOptions<'a> {
1198 fn from(acls: Acls<'a>) -> Self {
1199 LockOptions::new(acls)
1200 }
1201}
1202
/// Borrowed description of how lock contenders are recognized among sibling nodes.
#[derive(Clone, Copy)]
enum LockPrefixKind<'a> {
    /// Curator-style lock: contenders contain `lock_name`; each client's node is unique.
    Curator { lock_name: &'a str },
    /// Custom lock: contenders contain `lock_name`.
    Custom { lock_name: &'a str },
    /// Shared lock: contenders start with `prefix`.
    Shared { prefix: &'a str },
}

impl LockPrefixKind<'_> {
    /// Whether the sibling node `name` is a contender for this lock.
    fn filter(&self, name: &str) -> bool {
        match self {
            Self::Shared { prefix } => name.starts_with(prefix),
            Self::Curator { lock_name } | Self::Custom { lock_name } => name.contains(lock_name),
        }
    }

    /// Whether this client's lock node can be found by name alone (Curator style).
    fn is_unique(&self) -> bool {
        match self {
            Self::Curator { .. } => true,
            Self::Custom { .. } | Self::Shared { .. } => false,
        }
    }
}
1223
/// Concrete lock prefix variants backing [`LockPrefix`].
#[derive(Debug)]
enum LockPrefixInner<'a> {
    /// Lock nodes live in `dir` and embed `name` after a per-acquisition uuid.
    Curator { dir: &'a str, name: &'a str },
    /// Caller-provided path prefix that contains `name` (unless `name` is empty).
    Custom { prefix: String, name: &'a str },
    /// Path prefix shared by all contenders.
    Shared { prefix: &'a str },
}
1230
1231#[derive(Debug)]
1240pub struct LockPrefix<'a> {
1241 inner: LockPrefixInner<'a>,
1242}
1243
1244impl<'a> LockPrefix<'a> {
1245 pub fn new_curator(dir: &'a str, name: &'a str) -> Result<Self> {
1252 crate::util::validate_path(Chroot::default(), dir, false)?;
1253 if name.find('/').is_some() {
1254 return Err(Error::BadArguments(&"lock name must not contain /"));
1255 }
1256 Ok(Self { inner: LockPrefixInner::Curator { dir, name } })
1257 }
1258
1259 pub fn new_shared(prefix: &'a str) -> Result<Self> {
1271 crate::util::validate_path(Chroot::default(), prefix, true)?;
1272 Ok(Self { inner: LockPrefixInner::Shared { prefix } })
1273 }
1274
1275 pub fn new_custom(prefix: String, name: &'a str) -> Result<Self> {
1291 crate::util::validate_path(Chroot::default(), &prefix, true)?;
1292 if !name.is_empty() {
1293 let (_dir, _tree, this) = util::split_path(&prefix);
1294 if !this.contains(name) {
1295 return Err(Error::BadArguments(&"lock path prefix must contain lock name"));
1296 }
1297 }
1298 Ok(Self { inner: LockPrefixInner::Custom { prefix, name } })
1299 }
1300
1301 fn kind(&self) -> LockPrefixKind<'a> {
1302 match &self.inner {
1303 LockPrefixInner::Curator { name, .. } => LockPrefixKind::Curator { lock_name: name },
1304 LockPrefixInner::Shared { prefix } => {
1305 let (_parent, _tree, name) = util::split_path(prefix);
1306 LockPrefixKind::Shared { prefix: name }
1307 },
1308 LockPrefixInner::Custom { name, .. } => LockPrefixKind::Custom { lock_name: name },
1309 }
1310 }
1311
1312 fn into(self) -> String {
1313 match self.inner {
1314 LockPrefixInner::Curator { dir, name } => format!("{}/_c_{}-{}", dir, uuid::Uuid::new_v4(), name),
1315 LockPrefixInner::Shared { prefix } => prefix.to_string(),
1316 LockPrefixInner::Custom { prefix, .. } => prefix,
1317 }
1318 }
1319}
1320
/// Guard that, when dropped, hands `prefix` and `unique` to `Client::delete_ephemeral_background`.
///
/// NOTE(review): presumably used to clean up a lock node whose acquisition did not complete;
/// confirm at the construction site.
struct LockingGuard<'a> {
    zk: &'a Client,
    prefix: &'a str,
    // Taken from `LockPrefixKind::is_unique` (true only for Curator prefixes) — TODO confirm at caller.
    unique: bool,
}
1326
1327impl Drop for LockingGuard<'_> {
1328 fn drop(&mut self) {
1329 self.zk.clone().delete_ephemeral_background(self.prefix.to_string(), self.unique);
1330 }
1331}
1332
/// Client whose writes are committed together with a check on the lock node at `lock`.
///
/// Dropping it schedules `delete_background` of the lock path.
#[derive(Debug)]
pub struct LockClient<'a> {
    client: &'a Client,
    // Path of the lock node; borrowed or owned depending on how the value was built.
    lock: Cow<'a, str>,
}
1339
1340impl<'a> LockClient<'a> {
1341 async fn resolve_one_write(
1342 future: impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>>,
1343 ) -> Result<MultiWriteResult> {
1344 let mut results = future.await?;
1345 Ok(results.remove(0))
1346 }
1347
1348 pub fn client(&self) -> &'a Client {
1350 self.client
1351 }
1352
1353 pub fn lock_path(&self) -> &str {
1358 &self.lock
1359 }
1360
1361 pub fn create(
1369 &self,
1370 path: &str,
1371 data: &[u8],
1372 options: &CreateOptions<'_>,
1373 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a {
1374 Client::wait(self.create_internally(path, data, options))
1375 }
1376
1377 fn create_internally(
1378 &self,
1379 path: &str,
1380 data: &[u8],
1381 options: &CreateOptions<'_>,
1382 ) -> Result<impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'a> {
1383 let mut writer = self.client.new_check_writer(&self.lock, None)?;
1384 writer.add_create(path, data, options)?;
1385 let write = writer.commit();
1386 let path_len = path.len();
1391 Ok(async move {
1392 let result = Self::resolve_one_write(write).await?;
1393 let (created_path, stat) = result.into_create()?;
1394 let sequence = if created_path.len() <= path_len {
1395 CreateSequence(-1)
1396 } else {
1397 Client::parse_sequence(&created_path, &created_path[..path_len])?
1398 };
1399 Ok((stat, sequence))
1400 })
1401 }
1402
1403 pub fn set_data(
1405 &self,
1406 path: &str,
1407 data: &[u8],
1408 expected_version: Option<i32>,
1409 ) -> impl Future<Output = Result<Stat>> + Send + 'a {
1410 Client::wait(self.set_data_internally(path, data, expected_version))
1411 }
1412
1413 fn set_data_internally(
1414 &self,
1415 path: &str,
1416 data: &[u8],
1417 expected_version: Option<i32>,
1418 ) -> Result<impl Future<Output = Result<Stat>> + Send + 'a> {
1419 let mut writer = self.new_check_writer();
1420 writer.add_set_data(path, data, expected_version)?;
1421 let write = writer.commit();
1422 Ok(async move {
1423 let result = Self::resolve_one_write(write).await?;
1424 let stat = result.into_set_data()?;
1425 Ok(stat)
1426 })
1427 }
1428
1429 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + 'a {
1431 Client::wait(self.delete_internally(path, expected_version))
1432 }
1433
1434 fn delete_internally(
1435 &self,
1436 path: &str,
1437 expected_version: Option<i32>,
1438 ) -> Result<impl Future<Output = Result<()>> + Send + 'a> {
1439 let mut writer = self.new_check_writer();
1440 writer.add_delete(path, expected_version)?;
1441 let write = writer.commit();
1442 Ok(async move {
1443 let result = Self::resolve_one_write(write).await?;
1444 result.into_delete()
1445 })
1446 }
1447
1448 pub fn new_check_writer(&self) -> CheckWriter<'a> {
1450 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1451 }
1452
1453 pub fn into_owned(self) -> OwnedLockClient {
1455 let client = self.client.clone();
1456 let mut drop = ManuallyDrop::new(self);
1457 let lock = std::mem::take(drop.lock.to_mut());
1458 OwnedLockClient { client: ManuallyDrop::new(client), lock }
1459 }
1460}
1461
1462impl Drop for LockClient<'_> {
1464 fn drop(&mut self) {
1465 let path = std::mem::take(self.lock.to_mut());
1466 let client = self.client.clone();
1467 client.delete_background(path);
1468 }
1469}
1470
/// Owned counterpart of `LockClient`, holding its own `Client` and the lock path.
///
/// Dropping it schedules `delete_background` of the lock path.
///
/// NOTE(review): this type derives `Clone`, yet every clone's `Drop` schedules deletion of the
/// same lock node, so dropping any one clone releases the lock for all of them. Confirm that
/// `Clone` is intended.
#[derive(Clone, Debug)]
pub struct OwnedLockClient {
    // Wrapped so that `Drop` can take ownership of the client via `ManuallyDrop::take`.
    client: ManuallyDrop<Client>,
    lock: String,
}
1477
1478impl OwnedLockClient {
1479 fn lock_client(&self) -> std::mem::ManuallyDrop<LockClient<'_>> {
1480 std::mem::ManuallyDrop::new(LockClient { client: &self.client, lock: Cow::from(&self.lock) })
1481 }
1482
1483 pub fn client(&self) -> &Client {
1485 &self.client
1486 }
1487
1488 pub fn lock_path(&self) -> &str {
1490 &self.lock
1491 }
1492
1493 pub fn create<'a: 'f, 'b: 'f, 'f>(
1495 &'a self,
1496 path: &'b str,
1497 data: &[u8],
1498 options: &CreateOptions<'_>,
1499 ) -> impl Future<Output = Result<(Stat, CreateSequence)>> + Send + 'f {
1500 self.lock_client().create(path, data, options)
1501 }
1502
1503 pub fn set_data(
1505 &self,
1506 path: &str,
1507 data: &[u8],
1508 expected_version: Option<i32>,
1509 ) -> impl Future<Output = Result<Stat>> + Send + '_ {
1510 self.lock_client().set_data(path, data, expected_version)
1511 }
1512
1513 pub fn delete(&self, path: &str, expected_version: Option<i32>) -> impl Future<Output = Result<()>> + Send + '_ {
1515 self.lock_client().delete(path, expected_version)
1516 }
1517
1518 pub fn new_check_writer(&self) -> CheckWriter<'_> {
1520 unsafe { self.client.new_check_writer(&self.lock, None).unwrap_unchecked() }
1521 }
1522}
1523
impl Drop for OwnedLockClient {
    /// Hands the lock path to `Client::delete_background`, consuming the owned client.
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once and `self.client` is not accessed after this point,
        // so taking the value out of the `ManuallyDrop` cannot cause a double drop or a use
        // after move.
        let client = unsafe { ManuallyDrop::take(&mut self.client) };
        let path = std::mem::take(&mut self.lock);
        client.delete_background(path);
    }
}
1532
/// Server version as `(major, minor, patch)`.
///
/// The derived `PartialOrd` compares fields in declaration order, i.e. lexicographically.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub(crate) struct Version(u32, u32, u32);
1535
/// Builder-style connector that configures and establishes a `Client` session.
#[derive(Clone, Debug)]
pub struct Connector {
    // Taken (left `None`) by each connect attempt.
    #[cfg(feature = "tls")]
    tls: Option<TlsOptions>,
    // Taken (left `None`) by each connect attempt.
    #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
    sasl: Option<SaslOptions>,
    // Auth packets forwarded to the session builder via `with_authes`.
    authes: Vec<AuthPacket>,
    // Session to resume, if any; taken (left `None`) by each connect attempt.
    session: Option<SessionInfo>,
    readonly: bool,
    detached: bool,
    // When `false`, the endpoint iterator is switched to cycling mode before connecting.
    fail_eagerly: bool,
    // Assumed server version; `u32::MAX` in every component unless set.
    server_version: Version,
    // `Duration::ZERO` unless set; interpretation is left to the session builder.
    session_timeout: Duration,
    // `Duration::ZERO` unless set; interpretation is left to the session builder.
    connection_timeout: Duration,
}
1554
impl Connector {
    /// Creates a connector with every option at its default.
    ///
    /// The defaults are:
    /// * no TLS/SASL options, no auth packets, no session to resume;
    /// * not readonly, not detached, `fail_eagerly` off;
    /// * an assumed server version of `u32::MAX` in every component;
    /// * `Duration::ZERO` session and connection timeouts.
    fn new() -> Self {
        Self {
            #[cfg(feature = "tls")]
            tls: None,
            #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
            sasl: None,
            authes: Default::default(),
            session: None,
            readonly: false,
            detached: false,
            fail_eagerly: false,
            server_version: Version(u32::MAX, u32::MAX, u32::MAX),
            session_timeout: Duration::ZERO,
            connection_timeout: Duration::ZERO,
        }
    }

    /// Sets the session timeout forwarded to the session builder.
    pub fn session_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.session_timeout = timeout;
        self
    }

    /// Sets the connection timeout forwarded to the session builder.
    pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the readonly flag forwarded to the session builder.
    pub fn readonly(&mut self, readonly: bool) -> &mut Self {
        self.readonly = readonly;
        self
    }

    /// Appends an auth packet built from `scheme` and `auth`.
    pub fn auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut Self {
        self.authes.push(AuthPacket { scheme, auth });
        self
    }

    /// Sets a session to resume; it is consumed by the next connect attempt.
    pub fn session(&mut self, session: SessionInfo) -> &mut Self {
        self.session = Some(session);
        self
    }

    /// Overrides the assumed server version (default: `u32::MAX` in every component).
    pub fn server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
        self.server_version = Version(major, minor, patch);
        self
    }

    /// Sets the detached flag forwarded to the session builder.
    pub fn detached(&mut self) -> &mut Self {
        self.detached = true;
        self
    }

    /// Sets TLS options; consumed by the next connect attempt.
    #[cfg(feature = "tls")]
    pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
        self.tls = Some(options);
        self
    }

    /// Sets SASL options; consumed by the next connect attempt.
    #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
    pub fn sasl(&mut self, options: impl Into<SaslOptions>) -> &mut Self {
        self.sasl = Some(options.into());
        self
    }

    /// Disables cycling through endpoints, so connecting stops after trying the parsed endpoints.
    pub fn fail_eagerly(&mut self) -> &mut Self {
        self.fail_eagerly = true;
        self
    }

    /// Parses `cluster`, starts a session, spawns its serving task and returns the client.
    ///
    /// `secure` is forwarded to `endpoint::parse_connect_string`. The stored session, TLS and SASL
    /// options are taken (left `None`), so they apply only to this attempt.
    #[instrument(name = "connect", skip_all, fields(session))]
    async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
        let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
        let builder = Session::builder()
            .with_session(self.session.take())
            .with_authes(&self.authes)
            .with_readonly(self.readonly)
            .with_detached(self.detached)
            .with_session_timeout(self.session_timeout)
            .with_connection_timeout(self.connection_timeout);
        #[cfg(feature = "tls")]
        let builder = builder.with_tls(self.tls.take());
        #[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
        let builder = builder.with_sasl(self.sasl.take());
        let (mut session, state_receiver) = builder.build()?;
        let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
        endpoints.reset();
        // NOTE(review): without `fail_eagerly` the iterator is put in cycling mode, presumably
        // retrying endpoints indefinitely — confirm in `IterableEndpoints::cycle`.
        if !self.fail_eagerly {
            endpoints.cycle();
        }
        let mut buf = Vec::with_capacity(4096);
        let mut depot = Depot::new();
        let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
        let (sender, receiver) = mpsc::unbounded();
        let session_info = session.session.clone();
        let session_timeout = session.session_timeout;
        let mut state_watcher = StateWatcher::new(state_receiver);
        // NOTE(review): the returned state is discarded; presumably this marks the initial state
        // as already observed by the watcher — confirm.
        state_watcher.state();
        // The session is served on a detached task that owns the endpoints, connection, buffer
        // and depot from here on.
        asyncs::spawn(async move {
            session.serve(endpoints, conn, buf, depot, receiver).await;
        });
        let client =
            Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_watcher);
        Ok(client)
    }

    /// Connects to `cluster` with `secure` set to `true`.
    #[cfg(feature = "tls")]
    pub async fn secure_connect(&mut self, cluster: &str) -> Result<Client> {
        self.connect_internally(true, cluster).await
    }

    /// Connects to `cluster` with `secure` set to `false`.
    pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
        self.connect_internally(false, cluster).await
    }
}
1715
/// Builder for `Client` that forwards every setting to an inner `Connector`.
#[derive(Clone, Debug)]
pub struct ClientBuilder {
    connector: Connector,
}
1721
1722impl ClientBuilder {
1723 fn new() -> Self {
1724 Self { connector: Connector::new() }
1725 }
1726
1727 pub fn with_session_timeout(&mut self, timeout: Duration) -> &mut Self {
1731 self.connector.session_timeout(timeout);
1732 self
1733 }
1734
1735 pub fn with_connection_timeout(&mut self, timeout: Duration) -> &mut Self {
1739 self.connector.connection_timeout(timeout);
1740 self
1741 }
1742
1743 pub fn with_readonly(&mut self, readonly: bool) -> &mut ClientBuilder {
1745 self.connector.readonly = readonly;
1746 self
1747 }
1748
1749 pub fn with_auth(&mut self, scheme: String, auth: Vec<u8>) -> &mut ClientBuilder {
1751 self.connector.auth(scheme, auth);
1752 self
1753 }
1754
1755 pub fn assume_server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self {
1765 self.connector.server_version(major, minor, patch);
1766 self
1767 }
1768
1769 pub fn detach(&mut self) -> &mut Self {
1771 self.connector.detached();
1772 self
1773 }
1774
1775 pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
1781 self.connector.connect(cluster).await
1782 }
1783}
1784
1785trait MultiBuffer {
1786 fn buffer(&mut self) -> &mut Vec<u8>;
1787
1788 fn op_code() -> OpCode;
1789
1790 fn build_request(&mut self) -> MarshalledRequest {
1791 let buffer = self.buffer();
1792 if buffer.is_empty() {
1793 return Default::default();
1794 }
1795 let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
1796 buffer.append_record(&header);
1797 buffer.finish();
1798 MarshalledRequest(std::mem::take(buffer))
1799 }
1800
1801 fn add_operation(&mut self, op: OpCode, request: &impl Record) {
1802 let buffer = self.buffer();
1803 if buffer.is_empty() {
1804 let n = RequestHeader::record_len() + MultiHeader::record_len() + request.serialized_len();
1805 buffer.prepare_and_reserve(n);
1806 buffer.append_record(&RequestHeader::with_code(Self::op_code()));
1807 }
1808 let header = MultiHeader { op, done: false, err: -1 };
1809 self.buffer().append_record2(&header, request);
1810 }
1811}
1812
/// Result of one operation in a `MultiReader` batch; one entry per response, in response order.
#[non_exhaustive]
#[derive(Debug)]
pub enum MultiReadResult {
    /// Result of `add_get_data`: node data and stat.
    Data { data: Vec<u8>, stat: Stat },

    /// Result of `add_get_children`: child node names.
    Children { children: Vec<String> },

    /// The operation failed with `err`.
    Error { err: Error },
}
1826
/// Batches read operations (`add_get_data`, `add_get_children`) into one `MultiRead` request.
pub struct MultiReader<'a> {
    client: &'a Client,
    // Marshalled request under construction; empty until the first operation is added.
    buf: Vec<u8>,
}
1832
1833impl MultiBuffer for MultiReader<'_> {
1834 fn buffer(&mut self) -> &mut Vec<u8> {
1835 &mut self.buf
1836 }
1837
1838 fn op_code() -> OpCode {
1839 OpCode::MultiRead
1840 }
1841}
1842
1843impl<'a> MultiReader<'a> {
1844 fn new(client: &'a Client) -> MultiReader<'a> {
1845 MultiReader { client, buf: Default::default() }
1846 }
1847
1848 pub fn add_get_data(&mut self, path: &str) -> Result<()> {
1852 let chroot_path = self.client.validate_path(path)?;
1853 let request = GetRequest { path: chroot_path, watch: false };
1854 self.add_operation(OpCode::GetData, &request);
1855 Ok(())
1856 }
1857
1858 pub fn add_get_children(&mut self, path: &str) -> Result<()> {
1862 let chroot_path = self.client.validate_path(path)?;
1863 let request = GetChildrenRequest { path: chroot_path, watch: false };
1864 self.add_operation(OpCode::GetChildren, &request);
1865 Ok(())
1866 }
1867
1868 pub fn commit(&mut self) -> impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a {
1873 let request = self.build_request();
1874 Client::resolve(self.commit_internally(request))
1875 }
1876
1877 fn commit_internally(
1878 &self,
1879 request: MarshalledRequest,
1880 ) -> Result<Either<impl Future<Output = Result<Vec<MultiReadResult>>> + Send + 'a, Vec<MultiReadResult>>> {
1881 if request.is_empty() {
1882 return Ok(Right(Vec::default()));
1883 }
1884 let receiver = self.client.send_marshalled_request(request);
1885 Ok(Left(async move {
1886 let (body, _) = receiver.await?;
1887 let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1888 let mut results = Vec::with_capacity(response.len());
1889 for result in response {
1890 match result {
1891 MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1892 MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1893 MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1894 }
1895 }
1896 Ok(results)
1897 }))
1898 }
1899
1900 pub fn abort(&mut self) {
1902 self.buf.clear();
1903 }
1904}
1905
/// Result of one successful operation in a `MultiWriter` batch.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq)]
pub enum MultiWriteResult {
    /// Result of `add_check_version`.
    Check,

    /// Result of `add_delete`.
    Delete,

    /// Result of `add_create`.
    Create {
        /// Path of the created node with the client's chroot stripped; for sequential modes it
        /// includes the sequence suffix.
        path: String,

        /// Stat returned for the created node.
        stat: Stat,
    },

    /// Result of `add_set_data`.
    SetData {
        /// Stat returned by the set-data operation.
        stat: Stat,
    },
}
1936
1937impl MultiWriteResult {
1938 fn kind(&self) -> &'static str {
1939 match self {
1940 MultiWriteResult::Check => "MultiWriteResult::Check",
1941 MultiWriteResult::Create { .. } => "MultiWriteResult::Create",
1942 MultiWriteResult::Delete => "MultiWriteResult::Delete",
1943 MultiWriteResult::SetData { .. } => "MultiWriteResult::SetData",
1944 }
1945 }
1946
1947 fn into_check(self) -> Result<()> {
1948 match self {
1949 MultiWriteResult::Check => Ok(()),
1950 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Check, got {}", self.kind()))),
1951 }
1952 }
1953
1954 fn into_create(self) -> Result<(String, Stat)> {
1955 match self {
1956 MultiWriteResult::Create { path, stat } => Ok((path, stat)),
1957 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Create, got {}", self.kind()))),
1958 }
1959 }
1960
1961 fn into_set_data(self) -> Result<Stat> {
1962 match self {
1963 MultiWriteResult::SetData { stat } => Ok(stat),
1964 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::SetData, got {}", self.kind()))),
1965 }
1966 }
1967
1968 fn into_delete(self) -> Result<()> {
1969 match self {
1970 MultiWriteResult::Delete => Ok(()),
1971 _ => Err(Error::UnexpectedError(format!("expect MultiWriteResult::Delete, got {}", self.kind()))),
1972 }
1973 }
1974}
1975
/// Error returned by `MultiWriter::commit`.
#[derive(Error, Clone, Debug, PartialEq, Eq)]
pub enum MultiWriteError {
    /// The request as a whole failed, e.g. while awaiting or decoding the response.
    #[error("{source}")]
    RequestFailed {
        #[from]
        source: Error,
    },

    /// The operation at `index` (0-based position in the response list) reported `source`.
    #[error("operation at index {index} failed: {source}")]
    OperationFailed { index: usize, source: Error },
}
1988
1989impl From<MultiWriteError> for Error {
1990 fn from(err: MultiWriteError) -> Self {
1991 match err {
1992 MultiWriteError::RequestFailed { source } => source,
1993 MultiWriteError::OperationFailed { source, .. } => source,
1994 }
1995 }
1996}
1997
/// Error returned by `CheckWriter::commit`.
#[derive(Error, Clone, Debug, PartialEq, Eq)]
pub enum CheckWriteError {
    /// The request as a whole failed.
    #[error("request failed: {source}")]
    RequestFailed {
        #[from]
        source: Error,
    },

    /// The leading path check (operation 0 of the underlying multi write) failed.
    #[error("path check failed: {source}")]
    CheckFailed { source: Error },

    /// An operation added to the `CheckWriter` failed. `index` excludes the leading check, so it
    /// is 0-based over the added operations.
    #[error("operation at index {index} failed: {source}")]
    OperationFailed { index: usize, source: Error },
}
2013
2014impl From<MultiWriteError> for CheckWriteError {
2015 fn from(err: MultiWriteError) -> Self {
2016 match err {
2017 MultiWriteError::RequestFailed { source } => CheckWriteError::RequestFailed { source },
2018 MultiWriteError::OperationFailed { index: 0, source } => CheckWriteError::CheckFailed { source },
2019 MultiWriteError::OperationFailed { index, source } => {
2020 CheckWriteError::OperationFailed { index: index - 1, source }
2021 },
2022 }
2023 }
2024}
2025
2026impl From<CheckWriteError> for Error {
2027 fn from(err: CheckWriteError) -> Self {
2028 match err {
2029 CheckWriteError::RequestFailed { source } => source,
2030 CheckWriteError::CheckFailed { source: Error::NoNode | Error::BadVersion } => Error::RuntimeInconsistent,
2031 CheckWriteError::CheckFailed { source } => source,
2032 CheckWriteError::OperationFailed { source, .. } => source,
2033 }
2034 }
2035}
2036
/// `MultiWriter` whose batch starts with a path check. `commit` strips that check's result and
/// reports its failure as `CheckWriteError::CheckFailed`.
pub struct CheckWriter<'a> {
    writer: MultiWriter<'a>,
}
2041
2042impl<'a> CheckWriter<'a> {
2043 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2045 self.writer.add_check_version(path, version)
2046 }
2047
2048 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2050 self.writer.add_create(path, data, options)
2051 }
2052
2053 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2055 self.writer.add_set_data(path, data, expected_version)
2056 }
2057
2058 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2060 self.writer.add_delete(path, expected_version)
2061 }
2062
2063 pub fn commit(
2065 mut self,
2066 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, CheckWriteError>> + Send + 'a {
2067 let commit = self.writer.commit();
2068 async move {
2069 let mut results = commit.await?;
2070 if results.is_empty() {
2071 Err(CheckWriteError::RequestFailed {
2072 source: Error::UnexpectedError("expect path check, got none".to_string()),
2073 })
2074 } else {
2075 results.remove(0).into_check()?;
2076 Ok(results)
2077 }
2078 }
2079 }
2080}
2081
/// Batches write operations (check, create, set-data, delete) into one `Multi` request.
pub struct MultiWriter<'a> {
    client: &'a Client,
    // Marshalled request under construction; empty until the first operation is added.
    buf: Vec<u8>,
}
2087
2088impl MultiBuffer for MultiWriter<'_> {
2089 fn buffer(&mut self) -> &mut Vec<u8> {
2090 &mut self.buf
2091 }
2092
2093 fn op_code() -> OpCode {
2094 OpCode::Multi
2095 }
2096}
2097
2098impl<'a> MultiWriter<'a> {
2099 fn new(client: &'a Client) -> MultiWriter<'a> {
2100 MultiWriter { client, buf: Default::default() }
2101 }
2102
2103 pub fn add_check_version(&mut self, path: &str, version: i32) -> Result<()> {
2108 let chroot_path = self.client.validate_path(path)?;
2109 let request = CheckVersionRequest { path: chroot_path, version };
2110 self.add_operation(OpCode::Check, &request);
2111 Ok(())
2112 }
2113
2114 pub fn add_create(&mut self, path: &str, data: &[u8], options: &CreateOptions<'_>) -> Result<()> {
2125 options.validate()?;
2126 let ttl = options.ttl.map(|ttl| ttl.as_millis() as i64).unwrap_or(0);
2127 let create_mode = options.mode;
2128 let sequential = create_mode.is_sequential();
2129 let chroot_path =
2130 if sequential { self.client.validate_sequential_path(path)? } else { self.client.validate_path(path)? };
2131 let op_code = if ttl != 0 {
2132 OpCode::CreateTtl
2133 } else if create_mode.is_container() {
2134 OpCode::CreateContainer
2135 } else {
2136 OpCode::Create2
2137 };
2138 let flags = create_mode.as_flags(ttl != 0);
2139 let request = CreateRequest { path: chroot_path, data, acls: options.acls, flags, ttl };
2140 self.add_operation(op_code, &request);
2141 Ok(())
2142 }
2143
2144 pub fn add_set_data(&mut self, path: &str, data: &[u8], expected_version: Option<i32>) -> Result<()> {
2148 let chroot_path = self.client.validate_path(path)?;
2149 let request = SetDataRequest { path: chroot_path, data, version: expected_version.unwrap_or(-1) };
2150 self.add_operation(OpCode::SetData, &request);
2151 Ok(())
2152 }
2153
2154 pub fn add_delete(&mut self, path: &str, expected_version: Option<i32>) -> Result<()> {
2158 let chroot_path = self.client.validate_path(path)?;
2159 if chroot_path.is_root() {
2160 return Err(Error::BadArguments(&"can not delete root node"));
2161 }
2162 let request = DeleteRequest { path: chroot_path, version: expected_version.unwrap_or(-1) };
2163 self.add_operation(OpCode::Delete, &request);
2164 Ok(())
2165 }
2166
2167 pub fn commit(
2175 &mut self,
2176 ) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a {
2177 let request = self.build_request();
2178 Client::resolve(self.commit_internally(request))
2179 }
2180
2181 #[allow(clippy::type_complexity)]
2182 fn commit_internally(
2183 &self,
2184 request: MarshalledRequest,
2185 ) -> Result<
2186 Either<impl Future<Output = Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + 'a, Vec<MultiWriteResult>>,
2187 MultiWriteError,
2188 > {
2189 if request.is_empty() {
2190 return Ok(Right(Vec::default()));
2191 }
2192 let receiver = self.client.send_marshalled_request(request);
2193 let client = self.client;
2194 Ok(Left(async move {
2195 let (body, _) = receiver.await?;
2196 let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
2197 let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
2198 let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
2199 for (index, result) in response.into_iter().enumerate() {
2200 match result {
2201 MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
2202 MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
2203 MultiWriteResponse::Create { mut path, stat } => {
2204 path = util::strip_root_path(path, client.chroot.root())?;
2205 results.push(MultiWriteResult::Create { path: path.to_string(), stat });
2206 },
2207 MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
2208 MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
2209 MultiWriteResponse::Error(err) => {
2210 return Err(MultiWriteError::OperationFailed { index, source: err })
2211 },
2212 }
2213 }
2214 Ok(results)
2215 }))
2216 }
2217
2218 pub fn abort(&mut self) {
2220 self.buf.clear();
2221 }
2222}
2223
#[cfg(test)]
mod tests {
    use assertor::*;

    use super::*;

    // `CreateOptions::validate` rejects empty ACLs and invalid TTL combinations.
    #[test]
    fn test_create_options_validate() {
        assert_that!(CreateMode::Persistent.with_acls(Acls::new(Default::default())).validate().unwrap_err())
            .is_equal_to(Error::InvalidAcl);

        let acls = Acls::anyone_all();

        assert_that!(CreateMode::Ephemeral.with_acls(acls).with_ttl(Duration::from_secs(1)).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl can only be specified with persistent node"));

        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::ZERO).validate().unwrap_err())
            .is_equal_to(Error::BadArguments(&"ttl is zero"));

        // 0x01FFFFFFFFFF ms exceeds the maximum of 0xFFFFFFFFFF (1099511627775) ms.
        assert_that!(CreateMode::Persistent
            .with_acls(acls)
            .with_ttl(Duration::from_millis(0x01FFFFFFFFFF))
            .validate()
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"ttl cannot larger than 1099511627775"));

        assert_that!(CreateMode::Persistent.with_acls(acls).with_ttl(Duration::from_secs(5)).validate())
            .is_equal_to(Ok(()));
    }

    // Ancestor (directory) nodes must be neither ephemeral nor sequential.
    #[test]
    fn test_lock_options_with_ancestor_options() {
        let options = LockOptions::new(Acls::anyone_all());
        assert_that!(options
            .clone()
            .with_ancestor_options(CreateMode::Ephemeral.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be ephemeral"));
        assert_that!(options
            .with_ancestor_options(CreateMode::PersistentSequential.with_acls(Acls::anyone_all()))
            .unwrap_err())
        .is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
    }

    // Integration test against a ZooKeeper 3.9.0 container (requires Docker).
    #[test_log::test(asyncs::test)]
    async fn session_last_zxid_seen() {
        use testcontainers::clients::Cli as DockerCli;
        use testcontainers::core::{Healthcheck, WaitFor};
        use testcontainers::images::generic::GenericImage;

        let healthcheck = Healthcheck::default()
            .with_cmd(["./bin/zkServer.sh", "status"].iter())
            .with_interval(Duration::from_secs(2))
            .with_retries(60);
        let image =
            GenericImage::new("zookeeper", "3.9.0").with_healthcheck(healthcheck).with_wait_for(WaitFor::Healthcheck);
        let docker = DockerCli::default();
        let container = docker.run(image);
        let endpoint = format!("127.0.0.1:{}", container.get_host_port(2181));

        // A detached client, so its session can be taken out and resumed below.
        let client1 = Client::connector().detached().connect(&endpoint).await.unwrap();
        client1.create("/n1", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();

        let mut session = client1.into_session();

        // Claim a zxid beyond anything the server has seen; with `fail_eagerly` the connect
        // attempt is expected to fail with `Error::NoHosts`.
        session.last_zxid = i64::MAX;
        assert_that!(Client::connector().fail_eagerly().session(session.clone()).connect(&endpoint).await.unwrap_err())
            .is_equal_to(Error::NoHosts);

        // With a zxid of 0 the same session is expected to resume successfully.
        session.last_zxid = 0;
        let client2 = Client::connector().fail_eagerly().session(session.clone()).connect(&endpoint).await.unwrap();
        client2.create("/n2", b"", &CreateMode::Persistent.with_acls(Acls::anyone_all())).await.unwrap();
    }
}