willow_data_model/storage/
memory_store.rs

1use core::fmt::Debug;
2
3use std::{collections::BTreeMap, rc::Rc};
4
5use bytes::{Bytes, BytesMut};
6use ufotofu::{prelude::*, producer::clone_from_owned_slice};
7
8use frugal_async::Mutex;
9
10use crate::{
11    prelude::*,
12    storage::{CreateEntryError, NondestructiveInsert},
13};
14
15/// A non-persistent, in-memory Willow store implementation.
16#[derive(Debug)]
17pub struct MemoryStore<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT> {
18    rc: Rc<Mutex<Store_<MCL, MCC, MPL, N, S, PD, AT>>>,
19}
20
21impl<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT>
22    MemoryStore<MCL, MCC, MPL, N, S, PD, AT>
23{
24    /// Creates a new, empty Willow store.
25    pub fn new() -> Self {
26        MemoryStore {
27            rc: Rc::new(Mutex::new(Store_::new())),
28        }
29    }
30}
31
32impl<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT> Clone
33    for MemoryStore<MCL, MCC, MPL, N, S, PD, AT>
34{
35    fn clone(&self) -> Self {
36        Self {
37            rc: self.rc.clone(),
38        }
39    }
40}
41
42impl<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT>
43    MemoryStore<MCL, MCC, MPL, N, S, PD, AT>
44where
45    PD: Clone + Ord,
46    N: Clone,
47    S: Clone,
48{
49    async fn create_entry_impl<Data, P, H>(
50        &mut self,
51        data: &Data,
52        mut payload_producer: P,
53        payload_length: u64,
54        ingredients: &AT::Ingredients,
55        nondestructive: bool,
56    ) -> Result<
57        NondestructiveInsert<MCL, MCC, MPL, N, S, PD, AT>,
58        CreateEntryError<Infallible, P::Error, AT::CreationError>,
59    >
60    where
61        Data: ?Sized + Namespaced<N> + Coordinatelike<MCL, MCC, MPL, S>,
62        P: BulkProducer<Item = u8, Final = ()>,
63        H: Default + anyhash::Hasher<PD>,
64        AT: AuthorisationToken<MCL, MCC, MPL, N, S, PD> + Debug + Clone,
65        N: Debug + Clone + Ord,
66        S: Debug + Clone + Ord,
67        PD: Debug,
68    {
69        let mut payload = BytesMut::with_capacity(payload_length as usize);
70
71        let payload = loop {
72            match payload_producer
73                .expose_items_sync(|items| {
74                    payload.extend_from_slice(items);
75                    (items.len(), ())
76                })
77                .await
78            {
79                Ok(Left(())) => {}
80                Ok(Right(())) => {
81                    if payload.len() as u64 == payload_length {
82                        break payload.freeze();
83                    } else {
84                        return Err(CreateEntryError::IncorrectPayloadLength);
85                    }
86                }
87                Err(err) => return Err(CreateEntryError::ProducerError(err)),
88            }
89        };
90
91        let entry: Entry<MCL, MCC, MPL, N, S, PD> = Entry::builder()
92            .namespace_id(data.wdm_namespace_id().clone())
93            .subspace_id(data.wdm_subspace_id().clone())
94            .path(data.wdm_path().clone())
95            .timestamp(data.wdm_timestamp())
96            .payload::<_, H, PD>(&payload)
97            .build()
98            .unwrap();
99
100        let entry = entry
101            .into_authorised_entry(ingredients)
102            .map_err(|err| CreateEntryError::AuthorisationTokenError(err))?;
103
104        let mut store = self.rc.write().await;
105
106        Ok(store.do_insert_entry(entry, nondestructive, Some(payload)))
107    }
108}
109
110#[derive(Debug)]
111struct Store_<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT> {
112    namespaces: BTreeMap<N, NamespaceStore<MCL, MCC, MPL, S, PD, AT>>,
113}
114
115impl<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT>
116    Store_<MCL, MCC, MPL, N, S, PD, AT>
117{
118    fn new() -> Self {
119        Store_ {
120            namespaces: BTreeMap::new(),
121        }
122    }
123
124    fn get_or_create_namespace_store(
125        &mut self,
126        namespace_id: &N,
127    ) -> &mut NamespaceStore<MCL, MCC, MPL, S, PD, AT>
128    where
129        N: Ord + Clone,
130    {
131        if !self.namespaces.contains_key(namespace_id) {
132            let _ = self
133                .namespaces
134                .insert(namespace_id.clone(), NamespaceStore::new());
135        }
136
137        self.namespaces.get_mut(namespace_id).unwrap()
138    }
139
140    fn do_insert_entry(
141        &mut self,
142        authorised_entry: AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>,
143        prevent_pruning: bool,
144        payload: Option<Bytes>,
145    ) -> NondestructiveInsert<MCL, MCC, MPL, N, S, PD, AT>
146    where
147        N: Ord + Clone,
148        S: Ord + Clone,
149        PD: Ord + Clone,
150        AT: Clone,
151    {
152        let namespace_store =
153            self.get_or_create_namespace_store(authorised_entry.wdm_namespace_id());
154
155        let subspace_store =
156            namespace_store.get_or_create_subspace_store(authorised_entry.wdm_subspace_id());
157
158        // Is the inserted entry redundant? If so, return early.
159        for (path, entry) in subspace_store.entries.iter() {
160            if path.is_prefix_of(authorised_entry.wdm_path())
161                && entry.is_newer_than(authorised_entry.entry())
162            {
163                return NondestructiveInsert::Outdated;
164            }
165        }
166
167        if subspace_store.handle_insertion(&authorised_entry, prevent_pruning, true, payload) {
168            NondestructiveInsert::Prevented
169        } else {
170            NondestructiveInsert::Success(authorised_entry)
171        }
172    }
173}
174
175#[derive(Debug)]
176struct NamespaceStore<const MCL: usize, const MCC: usize, const MPL: usize, S, PD, AT> {
177    subspaces: BTreeMap<S, SubspaceStore<MCL, MCC, MPL, PD, AT>>,
178}
179
180impl<const MCL: usize, const MCC: usize, const MPL: usize, S, PD, AT>
181    NamespaceStore<MCL, MCC, MPL, S, PD, AT>
182{
183    fn new() -> Self {
184        NamespaceStore {
185            subspaces: BTreeMap::new(),
186        }
187    }
188
189    fn get_or_create_subspace_store(
190        &mut self,
191        subspace_id: &S,
192    ) -> &mut SubspaceStore<MCL, MCC, MPL, PD, AT>
193    where
194        S: Ord + Clone,
195    {
196        if !self.subspaces.contains_key(subspace_id) {
197            let _ = self
198                .subspaces
199                .insert(subspace_id.clone(), SubspaceStore::new());
200        }
201
202        self.subspaces.get_mut(subspace_id).unwrap()
203    }
204}
205
206#[derive(Debug)]
207struct SubspaceStore<const MCL: usize, const MCC: usize, const MPL: usize, PD, AT> {
208    entries: BTreeMap<Path<MCL, MCC, MPL>, ControlEntry<PD, AT>>,
209}
210
211impl<const MCL: usize, const MCC: usize, const MPL: usize, PD, AT>
212    SubspaceStore<MCL, MCC, MPL, PD, AT>
213{
214    fn new() -> Self {
215        Self {
216            entries: BTreeMap::new(),
217        }
218    }
219}
220
221impl<const MCL: usize, const MCC: usize, const MPL: usize, PD, AT>
222    SubspaceStore<MCL, MCC, MPL, PD, AT>
223where
224    PD: Ord + Clone,
225    AT: Clone,
226{
227    // Returns `true` if this would have pruned anything but the prevent_pruning flag was set.
228    // Inserts the entry if `do_insert_if_necessary` is true, and insertion is allowed (i.e., pruning does not need to be prevented).
229    fn handle_insertion<N, S>(
230        &mut self,
231        new_entry: &AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>,
232        prevent_pruning: bool,
233        do_insert_if_necessary: bool,
234        payload: Option<Bytes>,
235    ) -> bool {
236        // Does the new entry replace others?
237        let prune_these: Vec<_> = self
238            .entries
239            .iter()
240            .filter_map(|(path, entry)| {
241                if new_entry.wdm_path().is_prefix_of(path)
242                    && !entry.is_newer_than(new_entry.entry())
243                {
244                    Some(path.clone())
245                } else {
246                    None
247                }
248            })
249            .collect();
250
251        if prevent_pruning && !prune_these.is_empty() {
252            return true;
253        } else {
254            for path_to_prune in prune_these {
255                self.entries.remove(&path_to_prune);
256            }
257
258            if do_insert_if_necessary {
259                self.entries.insert(
260                    new_entry.wdm_path().clone(),
261                    ControlEntry {
262                        authorisation_token: new_entry.authorisation_token().clone(),
263                        payload: payload.unwrap_or(Bytes::new()),
264                        payload_digest: new_entry.wdm_payload_digest().clone(),
265                        payload_length: new_entry.wdm_payload_length(),
266                        timestamp: new_entry.wdm_timestamp(),
267                    },
268                );
269            }
270
271            return false;
272        }
273    }
274}
275
276#[derive(Debug, Clone)]
277struct ControlEntry<PD, AT> {
278    timestamp: Timestamp,
279    payload_length: u64,
280    payload_digest: PD,
281    authorisation_token: AT,
282    payload: Bytes,
283}
284
285impl<PD, AT> ControlEntry<PD, AT>
286where
287    PD: Ord,
288{
289    /// [newer than relation](https://willowprotocol.org/specs/data-model/index.html#entry_newer)
290    fn is_newer_than<const MCL: usize, const MCC: usize, const MPL: usize, N, S>(
291        &self,
292        entry: &Entry<MCL, MCC, MPL, N, S, PD>,
293    ) -> bool {
294        entry.wdm_timestamp() < self.timestamp
295            || (entry.wdm_timestamp() == self.timestamp
296                && *entry.wdm_payload_digest() < self.payload_digest)
297            || (entry.wdm_timestamp() == self.timestamp
298                && *entry.wdm_payload_digest() == self.payload_digest
299                && entry.wdm_payload_length() < self.payload_length)
300    }
301}
302
303impl<const MCL: usize, const MCC: usize, const MPL: usize, N, S, PD, AT>
304    Store<MCL, MCC, MPL, N, S, PD, AT> for MemoryStore<MCL, MCC, MPL, N, S, PD, AT>
305where
306    PD: Clone + Ord,
307    N: Clone + Ord,
308    S: Clone + Ord,
309    AT: Clone,
310{
311    type InternalError = Infallible;
312
313    async fn create_entry<Data, P, H>(
314        &mut self,
315        data: &Data,
316        payload_producer: P,
317        payload_length: u64,
318        ingredients: &AT::Ingredients,
319    ) -> Result<
320        Option<AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>>,
321        CreateEntryError<Self::InternalError, P::Error, AT::CreationError>,
322    >
323    where
324        Data: ?Sized + Namespaced<N> + Coordinatelike<MCL, MCC, MPL, S>,
325        P: BulkProducer<Item = u8, Final = ()>,
326        H: Default + anyhash::Hasher<PD>,
327        AT: AuthorisationToken<MCL, MCC, MPL, N, S, PD> + Debug,
328        N: Debug,
329        S: Debug,
330        PD: Debug,
331    {
332        match self
333            .create_entry_impl::<Data, P, H>(
334                data,
335                payload_producer,
336                payload_length,
337                ingredients,
338                false,
339            )
340            .await?
341        {
342            NondestructiveInsert::Success(yay) => Ok(Some(yay)),
343            NondestructiveInsert::Outdated => Ok(None),
344            NondestructiveInsert::Prevented => unreachable!(),
345        }
346    }
347
348    async fn create_entry_nondestructive<Data, P, H>(
349        &mut self,
350        data: &Data,
351        payload_producer: P,
352        payload_length: u64,
353        ingredients: &AT::Ingredients,
354    ) -> Result<
355        NondestructiveInsert<MCL, MCC, MPL, N, S, PD, AT>,
356        CreateEntryError<Self::InternalError, P::Error, AT::CreationError>,
357    >
358    where
359        Data: ?Sized + Namespaced<N> + Coordinatelike<MCL, MCC, MPL, S>,
360        P: BulkProducer<Item = u8, Final = ()>,
361        H: Default + anyhash::Hasher<PD>,
362        AT: AuthorisationToken<MCL, MCC, MPL, N, S, PD> + Debug,
363        N: Debug,
364        S: Debug,
365        PD: Debug,
366    {
367        self.create_entry_impl::<Data, P, H>(
368            data,
369            payload_producer,
370            payload_length,
371            ingredients,
372            true,
373        )
374        .await
375    }
376
377    async fn insert_entry(
378        &mut self,
379        entry: AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>,
380    ) -> Result<bool, Self::InternalError> {
381        let mut store = self.rc.write().await;
382
383        match store.do_insert_entry(entry, false, None) {
384            NondestructiveInsert::Outdated => Ok(false),
385            NondestructiveInsert::Prevented => unreachable!(),
386            NondestructiveInsert::Success(_yay) => Ok(true),
387        }
388    }
389
390    async fn forget_entry<K>(
391        &mut self,
392        namespace_id: &N,
393        key: &K,
394        expected_digest: Option<PD>,
395    ) -> Result<bool, Self::InternalError>
396    where
397        K: Keylike<MCL, MCC, MPL, S>,
398        PD: PartialEq,
399    {
400        let mut store = self.rc.write().await;
401
402        let namespace_store = store.get_or_create_namespace_store(namespace_id);
403
404        let subspace_store = namespace_store.get_or_create_subspace_store(key.wdm_subspace_id());
405
406        let found = subspace_store.entries.get(key.wdm_path());
407
408        match found {
409            None => Ok(false),
410            Some(entry) => {
411                if let Some(expected) = expected_digest {
412                    if entry.payload_digest != expected {
413                        return Ok(false);
414                    }
415                }
416
417                Ok(subspace_store.entries.remove(key.wdm_path()).is_some())
418            }
419        }
420    }
421
422    async fn forget_area(
423        &mut self,
424        namespace_id: &N,
425        area: &Area<MCL, MCC, MPL, S>,
426    ) -> Result<(), Self::InternalError> {
427        let entries_to_forget = {
428            let mut store = self.rc.write().await;
429
430            let namespace_store = store.get_or_create_namespace_store(namespace_id);
431
432            // Here we collect all entries we need to forget.
433            let mut entries_to_forget = vec![];
434
435            match area.subspace() {
436                Some(subspace_id) => {
437                    // Collect entries to forget for a single-subspace area.
438                    let subspace_store = namespace_store.get_or_create_subspace_store(subspace_id);
439
440                    for (path, entry) in subspace_store.entries.iter() {
441                        if !area.times().includes_value(&entry.timestamp) {
442                            continue;
443                        }
444
445                        if path.is_prefixed_by(area.path()) {
446                            entries_to_forget.push((subspace_id.clone(), path.clone()));
447                        }
448                    }
449                }
450                None => {
451                    // Collect entries to forget for an all-subspaces area.
452                    for (subspace_id, subspace_store) in namespace_store.subspaces.iter() {
453                        for (path, entry) in subspace_store.entries.iter() {
454                            if !area.times().includes_value(&entry.timestamp) {
455                                continue;
456                            }
457
458                            if path.is_prefixed_by(area.path()) {
459                                entries_to_forget.push((subspace_id.clone(), path.clone()));
460                            }
461                        }
462                    }
463                }
464            }
465
466            entries_to_forget
467        };
468
469        // And finally, forget the collected entries.
470        for forget_this in entries_to_forget {
471            self.forget_entry(namespace_id, &forget_this, None)
472                .await
473                .expect("cannot fail when expected_digest is None");
474        }
475
476        Ok(())
477    }
478
479    async fn forget_namespace(&mut self, namespace_id: &N) -> Result<(), Self::InternalError> {
480        let mut store = self.rc.write().await;
481
482        store.namespaces.remove(namespace_id);
483
484        Ok(())
485    }
486
487    async fn get_entry<K, Slice>(
488        &mut self,
489        namespace_id: &N,
490        key: &K,
491        expected_digest: Option<PD>,
492        payload_slice: &Slice,
493    ) -> Result<
494        Option<(
495            AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>,
496            impl BulkProducer<
497                Item = u8,
498                Final = (),
499                Error = super::PayloadProducerError<Self::InternalError>,
500            >,
501        )>,
502        Self::InternalError,
503    >
504    where
505        K: Keylike<MCL, MCC, MPL, S>,
506        Slice: std::ops::RangeBounds<u64>,
507    {
508        let mut store = self.rc.write().await;
509
510        let namespace_store = store.get_or_create_namespace_store(namespace_id);
511
512        let subspace_store = namespace_store.get_or_create_subspace_store(key.wdm_subspace_id());
513
514        match subspace_store.entries.get(key.wdm_path()) {
515            None => return Ok(None),
516            Some(found) => {
517                if let Some(expected) = expected_digest {
518                    if found.payload_digest != expected {
519                        return Ok(None);
520                    }
521                }
522
523                let entry = Entry::builder()
524                    .namespace_id(namespace_id.clone())
525                    .subspace_id(key.wdm_subspace_id().clone())
526                    .path(key.wdm_path().clone())
527                    .timestamp(found.timestamp)
528                    .payload_length(found.payload_length)
529                    .payload_digest(found.payload_digest.clone())
530                    .build()
531                    .unwrap();
532
533                // SAFETY: we checked authorisation upon insertion.
534                let authed = unsafe {
535                    PossiblyAuthorisedEntry {
536                        entry,
537                        authorisation_token: found.authorisation_token.clone(),
538                    }
539                    .into_authorised_entry_unchecked()
540                };
541
542                Ok(Some((
543                    authed,
544                    MapErr::new(
545                        clone_from_owned_slice(found.payload.slice((
546                            payload_slice.start_bound().map(|start| *start as usize),
547                            payload_slice.end_bound().map(|end| *end as usize),
548                        ))),
549                        |_| unreachable!(),
550                    ),
551                )))
552            }
553        }
554    }
555
556    async fn get_area(
557        &mut self,
558        namespace_id: N,
559        area: Area<MCL, MCC, MPL, S>,
560    ) -> impl Producer<
561        Item = AuthorisedEntry<MCL, MCC, MPL, N, S, PD, AT>,
562        Final = (),
563        Error = Self::InternalError,
564    > {
565        let entries_to_produce = {
566            let mut store = self.rc.write().await;
567
568            let namespace_store = store.get_or_create_namespace_store(&namespace_id);
569
570            // Here we collect all entries we need to produce.
571            let mut entries_to_produce = vec![];
572
573            match area.subspace() {
574                Some(subspace_id) => {
575                    // Collect entries to produce for a single-subspace area.
576                    let subspace_store = namespace_store.get_or_create_subspace_store(subspace_id);
577
578                    for (path, entry) in subspace_store.entries.iter() {
579                        if !area.times().includes_value(&entry.timestamp) {
580                            continue;
581                        }
582
583                        if path.is_prefixed_by(area.path()) {
584                            let new_entry = Entry::builder()
585                                .namespace_id(namespace_id.clone())
586                                .subspace_id(subspace_id.clone())
587                                .path(path.clone())
588                                .timestamp(entry.timestamp)
589                                .payload_length(entry.payload_length)
590                                .payload_digest(entry.payload_digest.clone())
591                                .build()
592                                .unwrap();
593
594                            // SAFETY: we checked authorisation upon insertion.
595                            let authed = unsafe {
596                                PossiblyAuthorisedEntry {
597                                    entry: new_entry,
598                                    authorisation_token: entry.authorisation_token.clone(),
599                                }
600                                .into_authorised_entry_unchecked()
601                            };
602
603                            entries_to_produce.push(authed);
604                        }
605                    }
606                }
607                None => {
608                    // Collect entries to produce for an all-subspaces area.
609                    for (subspace_id, subspace_store) in namespace_store.subspaces.iter() {
610                        for (path, entry) in subspace_store.entries.iter() {
611                            if !area.times().includes_value(&entry.timestamp) {
612                                continue;
613                            }
614
615                            if path.is_prefixed_by(area.path()) {
616                                let new_entry = Entry::builder()
617                                    .namespace_id(namespace_id.clone())
618                                    .subspace_id(subspace_id.clone())
619                                    .path(path.clone())
620                                    .timestamp(entry.timestamp)
621                                    .payload_length(entry.payload_length)
622                                    .payload_digest(entry.payload_digest.clone())
623                                    .build()
624                                    .unwrap();
625
626                                // SAFETY: we checked authorisation upon insertion.
627                                let authed = unsafe {
628                                    PossiblyAuthorisedEntry {
629                                        entry: new_entry,
630                                        authorisation_token: entry.authorisation_token.clone(),
631                                    }
632                                    .into_authorised_entry_unchecked()
633                                };
634
635                                entries_to_produce.push(authed);
636                            }
637                        }
638                    }
639                }
640            }
641
642            entries_to_produce
643        };
644
645        return MapErr::new(entries_to_produce.into_producer(), |_| unreachable!());
646    }
647
648    /// A no-op, because the store is not persistent.
649    async fn flush(&mut self) -> Result<(), Self::InternalError> {
650        Ok(())
651    }
652}
653
654/////////////////////////////////////////////////////////////////////////
655// TODO: publish the MapErr producer adaptor below as part of ufotofu. //
656/////////////////////////////////////////////////////////////////////////
657
/// A producer wrapper which changes the error (type) of the wrapped producer, by passing errors
/// through a function.
///
/// It is a [`BulkProducer`] whenever the wrapped producer is one. Because the mapping function
/// is an `FnOnce`, it can map only a single error: if the wrapped producer reports a second
/// error, this wrapper panics.
///
/// Use the `AsRef<P>` impl to borrow the wrapped producer, or [`MapErr::into_inner`] to take it
/// back.
///
/// Created via `MapErr::new`.
#[derive(Debug)]
pub struct MapErr<P, Fun> {
    inner: P,
    fun: Option<Fun>,
}

impl<P, Fun> MapErr<P, Fun> {
    /// Wraps `inner` so that its errors are passed through `fun`.
    pub(crate) fn new(inner: P, fun: Fun) -> Self {
        Self {
            inner,
            fun: Some(fun),
        }
    }

    /// Consumes `self` and returns the wrapped producer.
    pub fn into_inner(self) -> P {
        self.inner
    }
}
685
686impl<P, Fun> AsRef<P> for MapErr<P, Fun> {
687    fn as_ref(&self) -> &P {
688        &self.inner
689    }
690}
691
692impl<P, Fun, OldErr, NewErr> Producer for MapErr<P, Fun>
693where
694    P: Producer<Error = OldErr>,
695    Fun: FnOnce(OldErr) -> NewErr,
696{
697    type Item = P::Item;
698    type Final = P::Final;
699    type Error = NewErr;
700
701    async fn produce(&mut self) -> Result<Either<Self::Item, Self::Final>, Self::Error> {
702        self.inner.produce().await.map_err(|err| {
703            (self
704                .fun
705                .take()
706                .expect("Must not use a producer after it emitted an error"))(err)
707        })
708    }
709
710    async fn slurp(&mut self) -> Result<(), Self::Error> {
711        self.inner.slurp().await.map_err(|err| {
712            (self
713                .fun
714                .take()
715                .expect("Must not use a producer after it emitted an error"))(err)
716        })
717    }
718}
719
720impl<P, Fun, OldErr, NewErr> BulkProducer for MapErr<P, Fun>
721where
722    P: BulkProducer<Error = OldErr>,
723    Fun: FnOnce(OldErr) -> NewErr,
724{
725    async fn expose_items<F, R>(&mut self, f: F) -> Result<Either<R, Self::Final>, Self::Error>
726    where
727        F: AsyncFnOnce(&[Self::Item]) -> (usize, R),
728    {
729        self.inner.expose_items(f).await.map_err(|err| {
730            (self
731                .fun
732                .take()
733                .expect("Must not use a producer after it emitted an error"))(err)
734        })
735    }
736}