1use async_trait::async_trait;
16use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
17use tokio::sync::RwLock;
18use uuid::Uuid;
19
20use crate::{
21 LockApi,
22 drwmutex::{DRWMutex, Options},
23 lrwmutex::LRWMutex,
24};
25use std::io::Result;
26
/// Boxed, thread-safe trait object for any [`RWLocker`] implementation.
pub type RWLockerImpl = Box<dyn RWLocker + Send + Sync>;

/// Common interface for namespace lockers (local or distributed).
///
/// The acquire methods return `Ok(true)` when the lock was obtained and
/// `Ok(false)` when it was not.
#[async_trait]
pub trait RWLocker {
    /// Acquires the lock in exclusive (write) mode.
    async fn get_lock(&mut self, opts: &Options) -> Result<bool>;
    /// Releases an exclusive lock previously taken with `get_lock`.
    async fn un_lock(&mut self) -> Result<()>;
    /// Acquires the lock in shared (read) mode; the implementations in this
    /// module forward it to their read-lock primitives.
    async fn get_u_lock(&mut self, opts: &Options) -> Result<bool>;
    /// Releases a shared lock previously taken with `get_u_lock`.
    async fn un_r_lock(&mut self) -> Result<()>;
}
36
/// One per-resource entry of `NsLockMap`: the local read/write mutex plus a
/// count of callers currently using this entry.
#[derive(Debug)]
struct NsLock {
    /// Incremented on every `NsLockMap::lock` attempt, decremented on a failed
    /// attempt or on `NsLockMap::un_lock`; the entry is removed from the map
    /// when it reaches zero.
    reference: usize,
    /// The local read/write mutex guarding this resource.
    lock: LRWMutex,
}
42
/// In-process registry of per-resource locks, keyed by the joined
/// `volume`/`path` resource name.
#[derive(Debug, Default)]
pub struct NsLockMap {
    /// When `true`, `new_nslock` hands out distributed lockers instead of
    /// local ones backed by this map.
    is_dist_erasure: bool,
    /// Resource name -> lock entry. Entries are created on first lock and
    /// removed once their reference count drops to zero.
    lock_map: RwLock<HashMap<String, NsLock>>,
}
48
49impl NsLockMap {
50 pub fn new(is_dist_erasure: bool) -> Self {
51 Self {
52 is_dist_erasure,
53 ..Default::default()
54 }
55 }
56
57 async fn lock(
58 &mut self,
59 volume: &String,
60 path: &String,
61 lock_source: &str,
62 ops_id: &str,
63 read_lock: bool,
64 timeout: Duration,
65 ) -> bool {
66 let resource = Path::new(volume).join(path).to_str().unwrap().to_string();
67 let mut w_lock_map = self.lock_map.write().await;
68 let nslk = w_lock_map.entry(resource.clone()).or_insert(NsLock {
69 reference: 0,
70 lock: LRWMutex::default(),
71 });
72 nslk.reference += 1;
73
74 let locked = if read_lock {
75 nslk.lock.get_r_lock(ops_id, lock_source, &timeout).await
76 } else {
77 nslk.lock.get_lock(ops_id, lock_source, &timeout).await
78 };
79
80 if !locked {
81 nslk.reference -= 1;
82 if nslk.reference == 0 {
83 w_lock_map.remove(&resource);
84 }
85 }
86
87 locked
88 }
89
90 async fn un_lock(&mut self, volume: &String, path: &String, read_lock: bool) {
91 let resource = Path::new(volume).join(path).to_str().unwrap().to_string();
92 let mut w_lock_map = self.lock_map.write().await;
93 if let Some(nslk) = w_lock_map.get_mut(&resource) {
94 if read_lock {
95 nslk.lock.un_r_lock().await;
96 } else {
97 nslk.lock.un_lock().await;
98 }
99
100 nslk.reference -= 1;
101
102 if nslk.reference == 0 {
103 w_lock_map.remove(&resource);
104 }
105 }
106 }
107}
108
109pub struct WrapperLocker(pub Arc<RwLock<RWLockerImpl>>);
110
111impl Drop for WrapperLocker {
112 fn drop(&mut self) {
113 let inner = self.0.clone();
114 tokio::spawn(async move {
115 let _ = inner.write().await.un_lock().await;
116 });
117 }
118}
119
120pub async fn new_nslock(
121 ns: Arc<RwLock<NsLockMap>>,
122 owner: String,
123 volume: String,
124 paths: Vec<String>,
125 lockers: Vec<LockApi>,
126) -> WrapperLocker {
127 if ns.read().await.is_dist_erasure {
128 let names = paths
129 .iter()
130 .map(|path| Path::new(&volume).join(path).to_str().unwrap().to_string())
131 .collect();
132 return WrapperLocker(Arc::new(RwLock::new(Box::new(DistLockInstance::new(owner, names, lockers)))));
133 }
134
135 WrapperLocker(Arc::new(RwLock::new(Box::new(LocalLockInstance::new(ns, volume, paths)))))
136}
137
138struct DistLockInstance {
139 lock: Box<DRWMutex>,
140 ops_id: String,
141}
142
143impl DistLockInstance {
144 fn new(owner: String, names: Vec<String>, lockers: Vec<LockApi>) -> Self {
145 let ops_id = Uuid::new_v4().to_string();
146 Self {
147 lock: Box::new(DRWMutex::new(owner, names, lockers)),
148 ops_id,
149 }
150 }
151}
152
153#[async_trait]
154impl RWLocker for DistLockInstance {
155 async fn get_lock(&mut self, opts: &Options) -> Result<bool> {
156 let source = "".to_string();
157
158 Ok(self.lock.get_lock(&self.ops_id, &source, opts).await)
159 }
160
161 async fn un_lock(&mut self) -> Result<()> {
162 self.lock.un_lock().await;
163 Ok(())
164 }
165
166 async fn get_u_lock(&mut self, opts: &Options) -> Result<bool> {
167 let source = "".to_string();
168
169 Ok(self.lock.get_r_lock(&self.ops_id, &source, opts).await)
170 }
171
172 async fn un_r_lock(&mut self) -> Result<()> {
173 self.lock.un_r_lock().await;
174 Ok(())
175 }
176}
177
178struct LocalLockInstance {
179 ns: Arc<RwLock<NsLockMap>>,
180 volume: String,
181 paths: Vec<String>,
182 ops_id: String,
183}
184
185impl LocalLockInstance {
186 fn new(ns: Arc<RwLock<NsLockMap>>, volume: String, paths: Vec<String>) -> Self {
187 let ops_id = Uuid::new_v4().to_string();
188 Self {
189 ns,
190 volume,
191 paths,
192 ops_id,
193 }
194 }
195}
196
197#[async_trait]
198impl RWLocker for LocalLockInstance {
199 async fn get_lock(&mut self, opts: &Options) -> Result<bool> {
200 let source = "".to_string();
201 let read_lock = false;
202 let mut success = vec![false; self.paths.len()];
203 for (idx, path) in self.paths.iter().enumerate() {
204 if !self
205 .ns
206 .write()
207 .await
208 .lock(&self.volume, path, &source, &self.ops_id, read_lock, opts.timeout)
209 .await
210 {
211 for (i, x) in success.iter().enumerate() {
212 if *x {
213 self.ns.write().await.un_lock(&self.volume, &self.paths[i], read_lock).await;
214 }
215 }
216
217 return Ok(false);
218 }
219
220 success[idx] = true;
221 }
222 Ok(true)
223 }
224
225 async fn un_lock(&mut self) -> Result<()> {
226 let read_lock = false;
227 for path in self.paths.iter() {
228 self.ns.write().await.un_lock(&self.volume, path, read_lock).await;
229 }
230
231 Ok(())
232 }
233
234 async fn get_u_lock(&mut self, opts: &Options) -> Result<bool> {
235 let source = "".to_string();
236 let read_lock = true;
237 let mut success = Vec::with_capacity(self.paths.len());
238 for (idx, path) in self.paths.iter().enumerate() {
239 if !self
240 .ns
241 .write()
242 .await
243 .lock(&self.volume, path, &source, &self.ops_id, read_lock, opts.timeout)
244 .await
245 {
246 for (i, x) in success.iter().enumerate() {
247 if *x {
248 self.ns.write().await.un_lock(&self.volume, &self.paths[i], read_lock).await;
249 }
250 }
251
252 return Ok(false);
253 }
254
255 success[idx] = true;
256 }
257 Ok(true)
258 }
259
260 async fn un_r_lock(&mut self) -> Result<()> {
261 let read_lock = true;
262 for path in self.paths.iter() {
263 self.ns.write().await.un_lock(&self.volume, path, read_lock).await;
264 }
265
266 Ok(())
267 }
268}
269
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use std::io::Result;
    use tokio::sync::RwLock;

    use crate::{
        drwmutex::Options,
        namespace_lock::{NsLockMap, new_nslock},
    };

    /// Generous timeout so uncontended acquisitions never flake.
    fn test_options() -> Options {
        Options {
            timeout: Duration::from_secs(5),
            retry_interval: Duration::from_secs(1),
        }
    }

    #[tokio::test]
    async fn test_local_instance() -> Result<()> {
        let ns_lock_map = Arc::new(RwLock::new(NsLockMap::default()));
        let ns = new_nslock(
            Arc::clone(&ns_lock_map),
            "local".to_string(),
            "test".to_string(),
            vec!["foo".to_string()],
            Vec::new(),
        )
        .await;

        let result = ns.0.write().await.get_lock(&test_options()).await?;

        assert!(result);
        Ok(())
    }

    /// After an explicit `un_lock`, a different locker on the same resource
    /// must be able to take the write lock again.
    #[tokio::test]
    async fn test_local_instance_relock_after_unlock() -> Result<()> {
        let ns_lock_map = Arc::new(RwLock::new(NsLockMap::default()));
        let opts = test_options();

        let first = new_nslock(
            Arc::clone(&ns_lock_map),
            "local".to_string(),
            "test".to_string(),
            vec!["foo".to_string()],
            Vec::new(),
        )
        .await;
        assert!(first.0.write().await.get_lock(&opts).await?);
        first.0.write().await.un_lock().await?;

        let second = new_nslock(
            Arc::clone(&ns_lock_map),
            "local".to_string(),
            "test".to_string(),
            vec!["foo".to_string()],
            Vec::new(),
        )
        .await;
        assert!(second.0.write().await.get_lock(&opts).await?);
        second.0.write().await.un_lock().await?;

        Ok(())
    }
}