1use crate::error::Error;
2use crate::repo::{DataStore, PinKind, PinMode, PinModeRequirement, PinStore};
3use async_trait::async_trait;
4use futures::StreamExt;
5use ipld_core::cid::{self, Cid};
6use std::path::PathBuf;
7use tokio::sync::{Mutex, OwnedMutexGuard};
8
9use std::collections::hash_map::Entry;
10
11use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16#[derive(Debug, Default)]
18pub struct MemDataStore {
19 inner: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
20 pin: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
23}
24
25impl MemDataStore {
26 pub fn new(_: PathBuf) -> Self {
27 Default::default()
28 }
29
30 fn insert_pin<'a>(
32 g: &mut OwnedMutexGuard<HashMap<Vec<u8>, Vec<u8>>>,
33 target: &'a Cid,
34 kind: &'a PinKind<&'_ Cid>,
35 ) -> Result<bool, Error> {
36 let key = target.to_bytes();
40
41 match g.entry(key) {
42 Entry::Occupied(mut oe) => {
43 let mut doc: PinDocument = serde_json::from_slice(oe.get())?;
44 if doc.update(true, kind)? {
45 let vec = oe.get_mut();
46 vec.clear();
47 serde_json::to_writer(vec, &doc)?;
48 trace!(doc = ?doc, kind = ?kind, "updated on insert");
49 Ok(true)
50 } else {
51 trace!(doc = ?doc, kind = ?kind, "update not needed on insert");
52 Ok(false)
53 }
54 }
55 Entry::Vacant(ve) => {
56 let mut doc = PinDocument {
57 version: 0,
58 direct: false,
59 recursive: Recursive::Not,
60 cid_version: match target.version() {
61 cid::Version::V0 => 0,
62 cid::Version::V1 => 1,
63 },
64 indirect_by: Vec::new(),
65 };
66
67 doc.update(true, kind).unwrap();
68 let vec = serde_json::to_vec(&doc)?;
69 ve.insert(vec);
70 trace!(doc = ?doc, kind = ?kind, "created on insert");
71 Ok(true)
72 }
73 }
74 }
75
76 fn remove_pin<'a>(
78 g: &mut OwnedMutexGuard<HashMap<Vec<u8>, Vec<u8>>>,
79 target: &'a Cid,
80 kind: &'a PinKind<&'_ Cid>,
81 ) -> Result<bool, Error> {
82 let key = target.to_bytes();
84
85 match g.entry(key) {
86 Entry::Occupied(mut oe) => {
87 let mut doc: PinDocument = serde_json::from_slice(oe.get())?;
88 if !doc.update(false, kind)? {
89 trace!(doc = ?doc, kind = ?kind, "update not needed on removal");
90 return Ok(false);
91 }
92
93 if doc.can_remove() {
94 oe.remove();
95 } else {
96 let vec = oe.get_mut();
97 vec.clear();
98 serde_json::to_writer(vec, &doc)?;
99 }
100
101 Ok(true)
102 }
103 Entry::Vacant(_) => Err(anyhow::anyhow!("not pinned")),
104 }
105 }
106}
107
108#[async_trait]
109impl PinStore for MemDataStore {
110 async fn is_pinned(&self, block: &Cid) -> Result<bool, Error> {
111 let key = block.to_bytes();
112
113 let g = self.pin.lock().await;
114
115 Ok(g.contains_key(&key))
123 }
124
125 async fn insert_direct_pin(&self, target: &Cid) -> Result<(), Error> {
126 let mut g = Mutex::lock_owned(Arc::clone(&self.pin)).await;
127 Self::insert_pin(&mut g, target, &PinKind::Direct)?;
128 Ok(())
129 }
130
131 async fn remove_direct_pin(&self, target: &Cid) -> Result<(), Error> {
132 let mut g = Mutex::lock_owned(Arc::clone(&self.pin)).await;
133 Self::remove_pin(&mut g, target, &PinKind::Direct)?;
134 Ok(())
135 }
136
137 async fn insert_recursive_pin(
138 &self,
139 target: &Cid,
140 mut refs: crate::repo::References<'_>,
141 ) -> Result<(), Error> {
142 use futures::stream::TryStreamExt;
143
144 let mut g = Mutex::lock_owned(Arc::clone(&self.pin)).await;
145
146 Self::insert_pin(&mut g, target, &PinKind::RecursiveIntention)?;
148
149 let target_v1 = if target.version() == cid::Version::V1 {
150 target.to_owned()
151 } else {
152 Cid::new_v1(target.codec(), target.hash().to_owned())
154 };
155
156 let mut count = 0;
161 let kind = PinKind::IndirectFrom(&target_v1);
162 while let Some(next) = refs.try_next().await? {
163 Self::insert_pin(&mut g, &next, &kind)?;
165 count += 1;
166 }
167
168 let kind = PinKind::Recursive(count as u64);
169 Self::insert_pin(&mut g, target, &kind)?;
170
171 Ok(())
172 }
173
174 async fn remove_recursive_pin(
175 &self,
176 target: &Cid,
177 mut refs: crate::repo::References<'_>,
178 ) -> Result<(), Error> {
179 use futures::TryStreamExt;
180
181 let mut g = Mutex::lock_owned(Arc::clone(&self.pin)).await;
182
183 let doc: PinDocument = match g.get(&target.to_bytes()) {
184 Some(raw) => match serde_json::from_slice(raw) {
185 Ok(doc) => doc,
186 Err(e) => return Err(e.into()),
187 },
188 None => return Err(anyhow::anyhow!("not pinned or pinned indirectly")),
190 };
191
192 let kind = match doc.pick_kind() {
193 Some(Ok(kind @ PinKind::Recursive(_)))
194 | Some(Ok(kind @ PinKind::RecursiveIntention)) => kind,
195 Some(Ok(PinKind::Direct)) => {
196 Self::remove_pin(&mut g, target, &PinKind::Direct)?;
197 return Ok(());
198 }
199 Some(Ok(PinKind::IndirectFrom(cid))) => {
200 return Err(anyhow::anyhow!("pinned indirectly through {}", cid))
201 }
202 _ => return Err(anyhow::anyhow!("not pinned or pinned indirectly")),
204 };
205
206 Self::remove_pin(&mut g, target, &kind.as_ref())?;
208
209 let target_v1 = if target.version() == cid::Version::V1 {
210 target.to_owned()
211 } else {
212 Cid::new_v1(target.codec(), target.hash().to_owned())
214 };
215
216 let kind = PinKind::IndirectFrom(&target_v1);
217 while let Some(next) = refs.try_next().await? {
218 Self::remove_pin(&mut g, &next, &kind)?;
220 }
221
222 Ok(())
223 }
224
225 async fn list(
226 &self,
227 requirement: Option<PinMode>,
228 ) -> futures::stream::BoxStream<'static, Result<(Cid, PinMode), Error>> {
229 use futures::stream::StreamExt;
230 use std::convert::TryFrom;
231 let g = self.pin.lock().await;
232
233 let requirement = PinModeRequirement::from(requirement);
234
235 let copy = g
236 .iter()
237 .map(|(key, value)| {
238 let cid = Cid::try_from(key.as_slice())?;
239 let doc: PinDocument = serde_json::from_slice(value)?;
240 let mode = doc.mode().ok_or_else(|| anyhow::anyhow!("invalid mode"))?;
241
242 Ok((cid, mode))
243 })
244 .filter(move |res| {
245 match res {
247 Ok((_, mode)) => requirement.matches(mode),
248 Err(_) => true,
249 }
250 })
251 .collect::<Vec<_>>();
252
253 futures::stream::iter(copy).boxed()
254 }
255
256 async fn query(
257 &self,
258 cids: Vec<Cid>,
259 requirement: Option<PinMode>,
260 ) -> Result<Vec<(Cid, PinKind<Cid>)>, Error> {
261 let g = self.pin.lock().await;
262
263 let requirement = PinModeRequirement::from(requirement);
264
265 cids.into_iter()
266 .map(move |cid| {
267 match g.get(&cid.to_bytes()) {
268 Some(raw) => {
269 let doc: PinDocument = match serde_json::from_slice(raw) {
270 Ok(doc) => doc,
271 Err(e) => return Err(e.into()),
272 };
273 let mode = match doc.pick_kind() {
276 Some(Ok(kind)) => kind,
277 Some(Err(invalid_cid)) => return Err(Error::new(invalid_cid)),
278 None => {
279 trace!(doc = ?doc, "could not pick pin kind");
280 return Err(anyhow::anyhow!("{} is not pinned", cid));
281 }
282 };
283
284 let matches = requirement.matches(&mode);
288
289 if matches {
290 trace!(cid = %cid, req = ?requirement, "pin matches");
291 return Ok((cid, mode));
292 } else {
293 return Err(anyhow::anyhow!(
295 "{} is not pinned as {:?}",
296 cid,
297 requirement
298 .required()
299 .expect("matches is never false if requirement is none")
300 ));
301 }
302 }
303 None => {
304 trace!(cid = %cid, "no record found");
305 }
306 }
307
308 Err(anyhow::anyhow!("{} is not pinned", cid))
310 })
311 .collect::<Result<Vec<_>, _>>()
312 }
313}
314
315#[async_trait]
316impl DataStore for MemDataStore {
317 async fn init(&self) -> Result<(), Error> {
318 Ok(())
319 }
320
321 async fn contains(&self, key: &[u8]) -> Result<bool, Error> {
322 let contains = self.inner.lock().await.contains_key(key);
323 Ok(contains)
324 }
325
326 async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Error> {
327 let value = self
328 .inner
329 .lock()
330 .await
331 .get(key)
332 .map(|value| value.to_owned());
333 Ok(value)
334 }
335
336 async fn put(&self, key: &[u8], value: &[u8]) -> Result<(), Error> {
337 self.inner
338 .lock()
339 .await
340 .insert(key.to_owned(), value.to_owned());
341 Ok(())
342 }
343
344 async fn remove(&self, key: &[u8]) -> Result<(), Error> {
345 self.inner.lock().await.remove(key);
346 Ok(())
347 }
348
349 async fn iter(&self) -> futures::stream::BoxStream<'static, (Vec<u8>, Vec<u8>)> {
350 let list = self.inner.lock().await.clone();
351
352 let stream = async_stream::stream! {
353 for (k, v) in list {
354 yield (k, v)
355 }
356 };
357
358 stream.boxed()
359 }
360}
361
/// Recursive-pin state of a block, as stored inside its [`PinDocument`].
///
/// NOTE: the variant names appear in the serialized JSON of pin documents, so renaming
/// them changes the stored format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum Recursive {
    /// Recursive pin completed; holds the number of descendants recorded for it.
    Count(u64),
    /// A recursive pin was started but has not been completed with a descendant count yet.
    Intent,
    /// Not pinned recursively.
    Not,
}
375
376impl Recursive {
377 fn is_set(&self) -> bool {
378 match self {
379 Recursive::Count(_) | Recursive::Intent => true,
380 Recursive::Not => false,
381 }
382 }
383}
384
/// Serialized (JSON) record of every way a single block is pinned.
///
/// The field names and their order define the stored JSON layout.
#[derive(Debug, Serialize, Deserialize)]
struct PinDocument {
    /// Document format version; this module always writes `0`.
    version: u8,
    /// Whether the block is pinned directly.
    direct: bool,
    /// Recursive-pin state of the block.
    recursive: Recursive,
    /// Version (`0` or `1`) of the CID the document was created for.
    cid_version: u8,
    /// CIDv1 strings of the recursive roots through which the block is pinned indirectly.
    /// `PinDocument::update` keeps this list sorted and free of duplicates.
    indirect_by: Vec<String>,
}
396
397impl PinDocument {
398 fn update(&mut self, add: bool, kind: &PinKind<&'_ Cid>) -> Result<bool, PinUpdateError> {
399 match kind {
404 PinKind::IndirectFrom(root) => {
405 let root = if root.version() == cid::Version::V1 {
406 root.to_string()
407 } else {
408 Cid::new_v1(root.codec(), (*root).hash().to_owned()).to_string()
410 };
411
412 let modified = if self.indirect_by.is_empty() {
413 if add {
414 self.indirect_by.push(root);
415 true
416 } else {
417 false
418 }
419 } else {
420 let mut set = self
421 .indirect_by
422 .drain(..)
423 .collect::<std::collections::BTreeSet<_>>();
424
425 let modified = if add {
426 set.insert(root)
427 } else {
428 set.remove(&root)
429 };
430
431 self.indirect_by.extend(set);
432 modified
433 };
434
435 Ok(modified)
436 }
437 PinKind::Direct => {
438 if self.recursive.is_set() && !self.direct && add {
439 return Err(PinUpdateError::AlreadyPinnedRecursive);
442 }
443
444 if !add && !self.direct {
445 if !self.recursive.is_set() {
446 return Err(PinUpdateError::CannotUnpinUnpinned);
447 } else {
448 return Err(PinUpdateError::CannotUnpinDirectOnRecursivelyPinned);
449 }
450 }
451
452 let modified = self.direct != add;
453 self.direct = add;
454 Ok(modified)
455 }
456 PinKind::RecursiveIntention => {
457 let modified = if add {
458 match self.recursive {
459 Recursive::Count(_) => return Err(PinUpdateError::AlreadyPinnedRecursive),
460 Recursive::Intent => false,
463 Recursive::Not => {
464 self.recursive = Recursive::Intent;
465 self.direct = false;
466 true
467 }
468 }
469 } else {
470 match self.recursive {
471 Recursive::Count(_) | Recursive::Intent => {
472 self.recursive = Recursive::Not;
473 true
474 }
475 Recursive::Not => false,
476 }
477 };
478
479 Ok(modified)
480 }
481 PinKind::Recursive(descendants) => {
482 let descendants = *descendants;
483 let modified = if add {
484 match self.recursive {
485 Recursive::Count(other) if other != descendants => {
486 return Err(PinUpdateError::UnexpectedNumberOfDescendants(
487 other,
488 descendants,
489 ))
490 }
491 Recursive::Count(_) => false,
492 Recursive::Intent | Recursive::Not => {
493 self.recursive = Recursive::Count(descendants);
494 self.direct = false;
497 true
498 }
499 }
500 } else {
501 match self.recursive {
502 Recursive::Count(other) if other != descendants => {
503 return Err(PinUpdateError::UnexpectedNumberOfDescendants(
504 other,
505 descendants,
506 ))
507 }
508 Recursive::Count(_) | Recursive::Intent => {
509 self.recursive = Recursive::Not;
510 true
511 }
512 Recursive::Not => return Err(PinUpdateError::NotPinnedRecursive),
513 }
514 };
518 Ok(modified)
519 }
520 }
521 }
522
523 fn can_remove(&self) -> bool {
524 !self.direct && !self.recursive.is_set() && self.indirect_by.is_empty()
525 }
526
527 fn mode(&self) -> Option<PinMode> {
528 if self.recursive.is_set() {
529 Some(PinMode::Recursive)
530 } else if !self.indirect_by.is_empty() {
531 Some(PinMode::Indirect)
532 } else if self.direct {
533 Some(PinMode::Direct)
534 } else {
535 None
536 }
537 }
538
539 fn pick_kind(&self) -> Option<Result<PinKind<Cid>, cid::Error>> {
540 self.mode().map(|p| {
541 Ok(match p {
542 PinMode::Recursive => match self.recursive {
543 Recursive::Intent => PinKind::RecursiveIntention,
544 Recursive::Count(total) => PinKind::Recursive(total),
545 _ => unreachable!("mode should not have returned PinKind::Recursive"),
546 },
547 PinMode::Indirect => {
548 let cid = Cid::try_from(self.indirect_by[0].as_str())?;
552 PinKind::IndirectFrom(cid)
553 }
554 PinMode::Direct => PinKind::Direct,
555 })
556 })
557 }
558}
559
/// Reasons a pin document refuses a requested pin-state transition.
#[derive(Debug, thiserror::Error)]
pub enum PinUpdateError {
    /// A recursive-pin operation gave a descendant count different from the stored one.
    /// Fields: `(stored count, count given by the caller)`.
    #[error("unexpected number of descendants ({}), found {}", .1, .0)]
    UnexpectedNumberOfDescendants(u64, u64),
    /// Tried to remove a recursive pin from a block that is not pinned recursively.
    #[error("not pinned recursively")]
    NotPinnedRecursive,
    /// Tried either to add a direct pin to a recursively pinned block, or to start a new
    /// recursive pin on a block whose recursive pin is already complete.
    #[error("already pinned recursively")]
    AlreadyPinnedRecursive,
    /// Tried to remove a direct pin from a block pinned neither directly nor recursively.
    #[error("not pinned or pinned indirectly")]
    CannotUnpinUnpinned,
    /// Tried to remove a direct pin from a block that is pinned recursively, not directly.
    #[error("is pinned recursively")]
    CannotUnpinDirectOnRecursivelyPinned,
}
580
// Instantiate the shared `pinstore_interface_tests!` suite as `common_tests`, constructing
// stores through `MemDataStore::new`. The macro is defined elsewhere in the crate; it
// presumably exercises the `PinStore` interface generically. TODO confirm against its
// definition.
#[cfg(test)]
crate::pinstore_interface_tests!(
    common_tests,
    crate::repo::datastore::memory::MemDataStore::new
);
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a single key through the whole `DataStore` lifecycle: absent, stored, removed.
    #[tokio::test]
    async fn test_mem_datastore() {
        let store = MemDataStore::new(std::env::temp_dir());
        let key = [1, 2, 3, 4];
        let value = [5, 6, 7, 8];

        store.init().await.unwrap();

        // Nothing stored yet; removing a missing key must still succeed.
        assert!(!store.contains(&key).await.unwrap());
        assert_eq!(store.get(&key).await.unwrap(), None);
        store.remove(&key).await.unwrap();

        // Once stored, the value is visible and returned unchanged.
        store.put(&key, &value).await.unwrap();
        assert!(store.contains(&key).await.unwrap());
        assert_eq!(store.get(&key).await.unwrap(), Some(value.to_vec()));

        // After removal the key is gone again.
        store.remove(&key).await.unwrap();
        assert!(!store.contains(&key).await.unwrap());
        assert_eq!(store.get(&key).await.unwrap(), None);
    }

    /// A fresh document accepts a direct pin and then reports it as its only pin.
    #[test]
    fn pindocument_on_direct_pin() {
        let mut doc = PinDocument {
            version: 0,
            direct: false,
            recursive: Recursive::Not,
            cid_version: 0,
            indirect_by: Vec::new(),
        };

        let changed = doc.update(true, &PinKind::Direct).unwrap();
        assert!(changed);

        assert_eq!(doc.mode(), Some(PinMode::Direct));
        assert_eq!(doc.pick_kind().unwrap().unwrap(), PinKind::Direct);
    }
}