1mod error;
2mod lease;
3mod lock;
4
5use std::sync::Arc;
6
7pub use crate::{error::*, lease::*, lock::*};
8use etcd_client::{
9 Client, Compare, CompareOp, Event, PutOptions, ResponseHeader, Txn, TxnOp, WatchOptions,
10 WatchStream,
11};
12use futures::prelude::*;
13use log::*;
14use rand::random;
15
/// Selects which lease a key written by [`try_race`] is attached to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtlOrLeaseId {
    /// Grant a fresh lease; the value is passed as the TTL to `lease_grant`.
    Ttl(i64),
    /// Use an already existing lease with this id.
    LeaseId(i64),
}
20
21pub async fn try_race(
22 client: &mut Client,
23 key: &str,
24 ttl_or_lease_id: TtlOrLeaseId,
25) -> CrateResult<bool> {
26 let value = random::<u128>().to_le_bytes();
27
28 let lease_id = match ttl_or_lease_id {
29 TtlOrLeaseId::Ttl(ttl) => client.lease_grant(ttl, None).await?.id(),
30 TtlOrLeaseId::LeaseId(id) => id,
31 };
32
33 let locked = client
34 .txn(
35 Txn::new()
36 .when(vec![Compare::create_revision(
37 key.as_bytes(),
38 CompareOp::Equal,
39 0,
40 )])
41 .and_then(vec![TxnOp::put(
42 key.as_bytes(),
43 &value[..],
44 Some(PutOptions::new().with_lease(lease_id)),
45 )]),
46 )
47 .await?
48 .succeeded();
49
50 Ok(locked)
51}
52
53fn get_lock_txn(name: &str, lock_value: &[u8], lease_id: i64) -> Txn {
54 Txn::new()
55 .when([Compare::create_revision(name, CompareOp::Equal, 0)])
56 .and_then([TxnOp::put(
57 name,
58 lock_value,
59 Some(PutOptions::new().with_lease(lease_id)),
60 )])
61}
62
63pub(crate) fn get_release_txn(name: &str, lock_value: &[u8], lease_id: i64) -> Txn {
64 Txn::new()
65 .when([
66 Compare::value(name, CompareOp::Equal, lock_value),
67 Compare::lease(name, CompareOp::Equal, lease_id),
68 ])
69 .and_then([TxnOp::delete(name, None)])
70}
71
72async fn watch_until_lock_available(
73 etcd_client: &mut Client,
74 watch_stream: &mut Option<WatchStream>,
75 name: &str,
76 last_locked_revision: Option<i64>,
77) -> CrateResult<()> {
78 loop {
79 let mut etcd_client = etcd_client.clone();
80 let mut inner_watch_stream = match watch_stream.take() {
81 Some(watch_stream) => watch_stream,
82 None => {
83 let mut watch_options = WatchOptions::new();
84 if let Some(last_locked_revision) = last_locked_revision {
85 watch_options = watch_options.with_start_revision(last_locked_revision + 1);
86 }
87
88 let (_, watch_stream) = etcd_client.watch(name, Some(watch_options)).await?;
89
90 watch_stream
91 }
92 };
93
94 let watch_response = inner_watch_stream
95 .try_next()
96 .await?
97 .expect("Unexpected end of etcd watch stream");
98
99 *watch_stream = Some(inner_watch_stream);
100
101 let is_available = matches!(watch_response.events().last().and_then(Event::kv), Some(kv) if kv.create_revision() == 0);
102
103 trace!(
104 "Got update from watch stream for key: {} (avilable: {})",
105 name,
106 is_available
107 );
108
109 if is_available {
110 return Ok(());
111 }
112 }
113}
114
115impl Lease {
116 pub(crate) async fn stop_keep_alive(self) -> Result<(), etcd_client::Error> {
117 if let Ok(LeaseInner {
118 stop_tx,
119 keep_alive_handle,
120 }) = Arc::try_unwrap(self.inner)
121 {
122 let _ = stop_tx.send(());
123 keep_alive_handle.await.unwrap()?;
124 }
125
126 Ok(())
127 }
128
129 pub async fn release(self, client: &mut Client) -> Result<(), etcd_client::Error> {
139 let lease_id = self.lease_id;
140
141 self.stop_keep_alive().await?;
142 client.lease_revoke(lease_id).await?;
143
144 debug!("Released lease: {:#x}", lease_id);
145
146 Ok(())
147 }
148
149 async fn _lock<'l, 'n>(
150 &'l mut self,
151 client: &mut Client,
152 name: &'n str,
153 wait: bool,
154 ) -> CrateResult<LockGuard<'l, 'n>> {
155 let lock_value = random::<u128>().to_le_bytes();
156
157 let mut watch_stream = None;
158 let locked = loop {
159 let txn_response = client
160 .txn(get_lock_txn(name, &lock_value, self.lease_id))
161 .await?;
162
163 if txn_response.succeeded() {
164 break true;
165 } else if !wait {
166 break false;
167 }
168
169 watch_until_lock_available(
170 client,
171 &mut watch_stream,
172 name,
173 txn_response.header().map(ResponseHeader::revision),
174 )
175 .await?;
176 };
177
178 if !locked {
179 return Err(Error::Taken);
180 } else {
181 debug!("Acquired lock: {} (lease: {:#x})", name, self.lease_id);
182 }
183
184 Ok(LockGuard {
185 lock_value,
186 name,
187 lease: &*self,
188 })
189 }
190
191 pub async fn lock<'l, 'n>(
192 &'l mut self,
193 client: &mut Client,
194 name: &'n str,
195 ) -> CrateResult<LockGuard<'l, 'n>> {
196 self._lock(client, name, true).await
197 }
198
199 pub async fn try_lock<'l, 'n>(
205 &'l mut self,
206 client: &mut Client,
207 name: &'n str,
208 ) -> CrateResult<LockGuard<'l, 'n>> {
209 self._lock(client, name, false).await
210 }
211
212 pub async fn with_lock<F, Fut, T>(
239 &mut self,
240 client: &mut Client,
241 name: &str,
242 f: F,
243 ) -> CrateResult<T>
244 where
245 F: FnOnce() -> Fut,
246 Fut: Future<Output = T>,
247 {
248 let guard = self.lock(client, name).await?;
249 let result = f().await;
250 guard.release(client).await?;
251
252 Ok(result)
253 }
254}
255
// Integration tests. They connect to an etcd endpoint at `http://localhost:2379`.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{random, random_range};
    use std::{
        ops::AddAssign,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };
    use tokio::{
        sync::{Barrier, Mutex},
        time::sleep,
    };

    /// Connects to the etcd endpoint used by all tests.
    async fn create_client() -> CrateResult<Client> {
        Ok(Client::connect(vec!["http://localhost:2379"], None).await?)
    }

    /// Returns a random hexadecimal key name so tests do not share lock keys.
    fn create_lock_name() -> String {
        format!("{:x}", random::<u128>())
    }

    /// Runs `WORKERS` concurrent workers that all take the same lock and checks
    /// that every worker ran its critical section.
    #[tokio::test]
    async fn test_mutex() -> anyhow::Result<()> {
        const WORKERS: usize = 10;

        type Result<T> = anyhow::Result<T, anyhow::Error>;

        let lock_name = create_lock_name();

        let client = create_client().await?;
        let locked_count = Arc::new(Mutex::new(0_usize));

        stream::iter(0..WORKERS)
            .map(Result::Ok)
            .try_for_each_concurrent(None, |_| {
                let mut client = client.clone();
                let lock_name = lock_name.clone();
                let locked_count = locked_count.clone();

                async move {
                    let mut lease = acquire_lease(&mut client, 10).await?;

                    lease
                        .with_lock(&mut client, &lock_name, || async move {
                            {
                                // `try_lock` fails if another worker is inside its
                                // critical section at the same time.
                                let mut locked_count = locked_count.try_lock()?;
                                sleep(Duration::from_millis(random_range(0..2_000))).await;
                                locked_count.add_assign(1);
                            }

                            Result::Ok(())
                        })
                        .await??;

                    Result::Ok(())
                }
            })
            .await?;

        assert_eq!(*locked_count.try_lock().unwrap(), WORKERS);

        Ok(())
    }

    /// Takes the lock under `lease1`, checks that `lease2` sees it as taken, then
    /// stops the keep-alive of `lease1` and expects `lease2` to get the lock.
    #[tokio::test]
    async fn test_lease_expire() -> CrateResult<()> {
        let lock_name = create_lock_name();

        let mut client = create_client().await?;

        let lease1 = acquire_lease(&mut client, 10).await?;
        let mut lease2 = acquire_lease(&mut client, 10).await?;

        let barrier1 = Arc::new(Barrier::new(2));
        let barrier2 = Arc::new(Barrier::new(2));

        future::try_join(
            {
                let mut client = client.clone();
                let lock_name = lock_name.clone();
                let barrier1 = barrier1.clone();
                let barrier2 = barrier2.clone();
                async move {
                    client
                        .txn(get_lock_txn(&lock_name, b"", lease1.lease_id))
                        .await?;
                    barrier1.wait().await;
                    barrier2.wait().await;
                    lease1.stop_keep_alive().await?;
                    Ok(())
                }
            },
            {
                let mut client = client.clone();

                async move {
                    barrier1.wait().await;

                    // The lock is held under `lease1` at this point.
                    assert!(matches!(
                        lease2.try_lock(&mut client, &lock_name).await.unwrap_err(),
                        Error::Taken
                    ));

                    barrier2.wait().await;

                    // Waits until the lock key becomes available.
                    lease2
                        .with_lock(&mut client, &lock_name, || async {
                            info!("Inside lock.");
                        })
                        .await
                }
            },
        )
        .await?;

        Ok(())
    }

    /// Revokes the lease while the lock is held and expects `with_lock` to
    /// report `Error::Lost`.
    #[tokio::test]
    async fn test_lock_lost() -> CrateResult<()> {
        let lock_name = create_lock_name();

        let mut client = create_client().await?;
        let mut lease = acquire_lease(&mut client, 10).await?;

        let lease_id = lease.lease_id;

        let res = lease
            .with_lock(&mut client.clone(), &lock_name, || async {
                client.lease_revoke(lease_id).await.unwrap();
            })
            .await;

        assert!(matches!(res, Err(Error::Lost)));

        Ok(())
    }

    /// Lets 20 racers call `try_race` on the same key and checks that exactly one
    /// wins, then checks that the key can be won again after the TTL has passed.
    #[tokio::test]
    async fn test_try_race() -> CrateResult<()> {
        const TTL: u64 = 5;

        let lock_name = create_lock_name();

        let mut client = create_client().await?;
        let locked_count = Arc::new(AtomicUsize::new(0));

        stream::iter(0..20)
            .map(|_| {
                let mut client = client.clone();
                let lock_name = lock_name.clone();
                let locked_count = locked_count.clone();

                async move {
                    sleep(Duration::from_secs(random_range(0..TTL - 1))).await;

                    if try_race(&mut client, &lock_name, TtlOrLeaseId::Ttl(TTL as i64)).await? {
                        locked_count.fetch_add(1, Ordering::SeqCst);
                    }

                    CrateResult::Ok(())
                }
            })
            .buffer_unordered(20)
            .try_collect::<Vec<_>>()
            .await?;

        assert_eq!(locked_count.load(Ordering::SeqCst), 1);

        // Start the wait only now. The winner grants its lease after a random
        // delay, so a timer started before the race could fire while that lease
        // is still alive.
        sleep(Duration::from_secs(TTL + 1)).await;
        assert!(try_race(&mut client, &lock_name, TtlOrLeaseId::Ttl(10)).await?);

        Ok(())
    }
}