warg_server/datastore/memory.rs

1use super::{DataStore, DataStoreError};
2use futures::Stream;
3use indexmap::{IndexMap, IndexSet};
4use std::{pin::Pin, sync::Arc};
5use tokio::sync::RwLock;
6use warg_crypto::{hash::AnyHash, Encode, Signable};
7use warg_protocol::{
8    operator,
9    package::{self, PackageEntry},
10    registry::{
11        LogId, LogLeaf, PackageName, RecordId, RegistryIndex, RegistryLen, TimestampedCheckpoint,
12    },
13    ProtoEnvelope, PublishedProtoEnvelope, SerdeEnvelope,
14};
15
16struct Entry<R> {
17    registry_index: RegistryIndex,
18    record_content: ProtoEnvelope<R>,
19}
20
21struct Log<S, R> {
22    state: S,
23    entries: Vec<Entry<R>>,
24}
25
26impl<S, R> Default for Log<S, R>
27where
28    S: Default,
29{
30    fn default() -> Self {
31        Self {
32            state: S::default(),
33            entries: Vec::new(),
34        }
35    }
36}
37
38struct Record {
39    /// Index in the log's entries.
40    index: usize,
41    /// Index in the registry's log.
42    registry_index: RegistryIndex,
43}
44
45enum PendingRecord {
46    Operator {
47        record: Option<ProtoEnvelope<operator::OperatorRecord>>,
48    },
49    Package {
50        record: Option<ProtoEnvelope<package::PackageRecord>>,
51        missing: IndexSet<AnyHash>,
52    },
53}
54
55enum RejectedRecord {
56    Operator {
57        record: ProtoEnvelope<operator::OperatorRecord>,
58        reason: String,
59    },
60    Package {
61        record: ProtoEnvelope<package::PackageRecord>,
62        reason: String,
63    },
64}
65
66enum RecordStatus {
67    Pending(PendingRecord),
68    Rejected(RejectedRecord),
69    Validated(Record),
70}
71
72#[derive(Default)]
73struct State {
74    operators: IndexMap<LogId, Log<operator::LogState, operator::OperatorRecord>>,
75    packages: IndexMap<LogId, Log<package::LogState, package::PackageRecord>>,
76    package_names: IndexMap<LogId, Option<PackageName>>,
77    checkpoints: IndexMap<RegistryLen, SerdeEnvelope<TimestampedCheckpoint>>,
78    records: IndexMap<LogId, IndexMap<RecordId, RecordStatus>>,
79    log_leafs: IndexMap<RegistryIndex, LogLeaf>,
80}
81
82/// Represents an in-memory data store.
83///
84/// Data is not persisted between restarts of the server.
85///
86/// Note: this is mainly used for testing, so it is not very efficient as
87/// it shares a single RwLock for all operations.
88pub struct MemoryDataStore(Arc<RwLock<State>>);
89
90impl MemoryDataStore {
91    pub fn new() -> Self {
92        Self(Arc::new(RwLock::new(State::default())))
93    }
94}
95
96impl Default for MemoryDataStore {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102#[axum::async_trait]
103impl DataStore for MemoryDataStore {
104    async fn get_all_checkpoints(
105        &self,
106    ) -> Result<
107        Pin<Box<dyn Stream<Item = Result<TimestampedCheckpoint, DataStoreError>> + Send>>,
108        DataStoreError,
109    > {
110        Ok(Box::pin(futures::stream::empty()))
111    }
112
113    async fn get_all_validated_records(
114        &self,
115    ) -> Result<Pin<Box<dyn Stream<Item = Result<LogLeaf, DataStoreError>> + Send>>, DataStoreError>
116    {
117        Ok(Box::pin(futures::stream::empty()))
118    }
119
120    async fn get_log_leafs_starting_with_registry_index(
121        &self,
122        starting_index: RegistryIndex,
123        limit: usize,
124    ) -> Result<Vec<(RegistryIndex, LogLeaf)>, DataStoreError> {
125        let state = self.0.read().await;
126
127        let limit = if limit > state.log_leafs.len() - starting_index {
128            state.log_leafs.len() - starting_index
129        } else {
130            limit
131        };
132
133        let mut leafs = Vec::with_capacity(limit);
134        for entry in starting_index..starting_index + limit {
135            match state.log_leafs.get(&entry) {
136                Some(log_leaf) => leafs.push((entry, log_leaf.clone())),
137                None => break,
138            }
139        }
140
141        Ok(leafs)
142    }
143
144    async fn get_log_leafs_with_registry_index(
145        &self,
146        entries: &[RegistryIndex],
147    ) -> Result<Vec<LogLeaf>, DataStoreError> {
148        let state = self.0.read().await;
149
150        let mut leafs = Vec::with_capacity(entries.len());
151        for entry in entries {
152            match state.log_leafs.get(entry) {
153                Some(log_leaf) => leafs.push(log_leaf.clone()),
154                None => return Err(DataStoreError::LogLeafNotFound(*entry)),
155            }
156        }
157
158        Ok(leafs)
159    }
160
161    async fn get_package_names(
162        &self,
163        log_ids: &[LogId],
164    ) -> Result<IndexMap<LogId, Option<PackageName>>, DataStoreError> {
165        let state = self.0.read().await;
166
167        log_ids
168            .iter()
169            .map(|log_id| {
170                if let Some(opt_package_name) = state.package_names.get(log_id) {
171                    Ok((log_id.clone(), opt_package_name.clone()))
172                } else {
173                    Err(DataStoreError::LogNotFound(log_id.clone()))
174                }
175            })
176            .collect::<Result<IndexMap<LogId, Option<PackageName>>, _>>()
177    }
178
179    async fn store_operator_record(
180        &self,
181        log_id: &LogId,
182        record_id: &RecordId,
183        record: &ProtoEnvelope<operator::OperatorRecord>,
184    ) -> Result<(), DataStoreError> {
185        let mut state = self.0.write().await;
186        let prev = state.records.entry(log_id.clone()).or_default().insert(
187            record_id.clone(),
188            RecordStatus::Pending(PendingRecord::Operator {
189                record: Some(record.clone()),
190            }),
191        );
192
193        assert!(prev.is_none());
194        Ok(())
195    }
196
197    async fn reject_operator_record(
198        &self,
199        log_id: &LogId,
200        record_id: &RecordId,
201        reason: &str,
202    ) -> Result<(), DataStoreError> {
203        let mut state = self.0.write().await;
204
205        let status = state
206            .records
207            .get_mut(log_id)
208            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
209            .get_mut(record_id)
210            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
211
212        let record = match status {
213            RecordStatus::Pending(PendingRecord::Operator { record }) => record.take().unwrap(),
214            _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
215        };
216
217        *status = RecordStatus::Rejected(RejectedRecord::Operator {
218            record,
219            reason: reason.to_string(),
220        });
221
222        Ok(())
223    }
224
225    async fn commit_operator_record(
226        &self,
227        log_id: &LogId,
228        record_id: &RecordId,
229        registry_index: RegistryIndex,
230    ) -> Result<(), DataStoreError> {
231        let mut state = self.0.write().await;
232
233        let State {
234            operators,
235            records,
236            log_leafs,
237            ..
238        } = &mut *state;
239
240        let status = records
241            .get_mut(log_id)
242            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
243            .get_mut(record_id)
244            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
245
246        match status {
247            RecordStatus::Pending(PendingRecord::Operator { record }) => {
248                let record = record.take().unwrap();
249                let log = operators.entry(log_id.clone()).or_default();
250                match log
251                    .state
252                    .clone()
253                    .validate(&record)
254                    .map_err(DataStoreError::from)
255                {
256                    Ok(s) => {
257                        log.state = s;
258                        let index = log.entries.len();
259                        log.entries.push(Entry {
260                            registry_index,
261                            record_content: record,
262                        });
263                        *status = RecordStatus::Validated(Record {
264                            index,
265                            registry_index,
266                        });
267                        log_leafs.insert(
268                            registry_index,
269                            LogLeaf {
270                                log_id: log_id.clone(),
271                                record_id: record_id.clone(),
272                            },
273                        );
274                        Ok(())
275                    }
276                    Err(e) => {
277                        *status = RecordStatus::Rejected(RejectedRecord::Operator {
278                            record,
279                            reason: e.to_string(),
280                        });
281                        Err(e)
282                    }
283                }
284            }
285            _ => Err(DataStoreError::RecordNotPending(record_id.clone())),
286        }
287    }
288
289    async fn store_package_record(
290        &self,
291        log_id: &LogId,
292        package_name: &PackageName,
293        record_id: &RecordId,
294        record: &ProtoEnvelope<package::PackageRecord>,
295        missing: &IndexSet<&AnyHash>,
296    ) -> Result<(), DataStoreError> {
297        // Ensure the set of missing hashes is a subset of the record contents.
298        debug_assert!({
299            use warg_protocol::Record;
300            let contents = record.as_ref().contents();
301            missing.is_subset(&contents)
302        });
303
304        let mut state = self.0.write().await;
305        let prev = state.records.entry(log_id.clone()).or_default().insert(
306            record_id.clone(),
307            RecordStatus::Pending(PendingRecord::Package {
308                record: Some(record.clone()),
309                missing: missing.iter().map(|&d| d.clone()).collect(),
310            }),
311        );
312        state
313            .package_names
314            .insert(log_id.clone(), Some(package_name.clone()));
315
316        assert!(prev.is_none());
317        Ok(())
318    }
319
320    async fn reject_package_record(
321        &self,
322        log_id: &LogId,
323        record_id: &RecordId,
324        reason: &str,
325    ) -> Result<(), DataStoreError> {
326        let mut state = self.0.write().await;
327
328        let status = state
329            .records
330            .get_mut(log_id)
331            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
332            .get_mut(record_id)
333            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
334
335        let record = match status {
336            RecordStatus::Pending(PendingRecord::Package { record, .. }) => record.take().unwrap(),
337            _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
338        };
339
340        *status = RecordStatus::Rejected(RejectedRecord::Package {
341            record,
342            reason: reason.to_string(),
343        });
344
345        Ok(())
346    }
347
348    async fn commit_package_record(
349        &self,
350        log_id: &LogId,
351        record_id: &RecordId,
352        registry_index: RegistryIndex,
353    ) -> Result<(), DataStoreError> {
354        let mut state = self.0.write().await;
355
356        let State {
357            packages,
358            records,
359            log_leafs,
360            ..
361        } = &mut *state;
362
363        let status = records
364            .get_mut(log_id)
365            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
366            .get_mut(record_id)
367            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
368
369        match status {
370            RecordStatus::Pending(PendingRecord::Package { record, .. }) => {
371                let record = record.take().unwrap();
372                let log = packages.entry(log_id.clone()).or_default();
373                match log
374                    .state
375                    .clone()
376                    .validate(&record)
377                    .map_err(DataStoreError::from)
378                {
379                    Ok(state) => {
380                        log.state = state;
381                        let index = log.entries.len();
382                        log.entries.push(Entry {
383                            registry_index,
384                            record_content: record,
385                        });
386                        *status = RecordStatus::Validated(Record {
387                            index,
388                            registry_index,
389                        });
390                        log_leafs.insert(
391                            registry_index,
392                            LogLeaf {
393                                log_id: log_id.clone(),
394                                record_id: record_id.clone(),
395                            },
396                        );
397                        Ok(())
398                    }
399                    Err(e) => {
400                        *status = RecordStatus::Rejected(RejectedRecord::Package {
401                            record,
402                            reason: e.to_string(),
403                        });
404                        Err(e)
405                    }
406                }
407            }
408            _ => Err(DataStoreError::RecordNotPending(record_id.clone())),
409        }
410    }
411
412    async fn is_content_missing(
413        &self,
414        log_id: &LogId,
415        record_id: &RecordId,
416        digest: &AnyHash,
417    ) -> Result<bool, DataStoreError> {
418        let state = self.0.read().await;
419        let log = state
420            .records
421            .get(log_id)
422            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
423
424        let status = log
425            .get(record_id)
426            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
427
428        match status {
429            RecordStatus::Pending(PendingRecord::Operator { .. }) => {
430                // Operator records have no content
431                Ok(false)
432            }
433            RecordStatus::Pending(PendingRecord::Package { missing, .. }) => {
434                Ok(missing.contains(digest))
435            }
436            _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
437        }
438    }
439
440    async fn set_content_present(
441        &self,
442        log_id: &LogId,
443        record_id: &RecordId,
444        digest: &AnyHash,
445    ) -> Result<bool, DataStoreError> {
446        let mut state = self.0.write().await;
447        let log = state
448            .records
449            .get_mut(log_id)
450            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
451
452        let status = log
453            .get_mut(record_id)
454            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
455
456        match status {
457            RecordStatus::Pending(PendingRecord::Operator { .. }) => {
458                // Operator records have no content, so conceptually already present
459                Ok(false)
460            }
461            RecordStatus::Pending(PendingRecord::Package { missing, .. }) => {
462                if missing.is_empty() {
463                    return Ok(false);
464                }
465
466                // Return true if this was the last missing content
467                missing.swap_remove(digest);
468                Ok(missing.is_empty())
469            }
470            _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
471        }
472    }
473
474    async fn store_checkpoint(
475        &self,
476        _checkpoint_id: &AnyHash,
477        ts_checkpoint: SerdeEnvelope<TimestampedCheckpoint>,
478    ) -> Result<(), DataStoreError> {
479        let mut state = self.0.write().await;
480
481        state
482            .checkpoints
483            .insert(ts_checkpoint.as_ref().checkpoint.log_length, ts_checkpoint);
484
485        Ok(())
486    }
487
488    async fn get_latest_checkpoint(
489        &self,
490    ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, DataStoreError> {
491        let state = self.0.read().await;
492        let checkpoint = state.checkpoints.values().last().unwrap();
493        Ok(checkpoint.clone())
494    }
495
496    async fn get_checkpoint(
497        &self,
498        log_length: RegistryLen,
499    ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, DataStoreError> {
500        let state = self.0.read().await;
501        let checkpoint = state
502            .checkpoints
503            .get(&log_length)
504            .ok_or_else(|| DataStoreError::CheckpointNotFound(log_length))?;
505        Ok(checkpoint.clone())
506    }
507
508    async fn get_operator_records(
509        &self,
510        log_id: &LogId,
511        registry_log_length: RegistryLen,
512        since: Option<&RecordId>,
513        limit: u16,
514    ) -> Result<Vec<PublishedProtoEnvelope<operator::OperatorRecord>>, DataStoreError> {
515        let state = self.0.read().await;
516
517        let log = state
518            .operators
519            .get(log_id)
520            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
521
522        if !state.checkpoints.contains_key(&registry_log_length) {
523            return Err(DataStoreError::CheckpointNotFound(registry_log_length));
524        };
525
526        let start_log_idx = match since {
527            Some(since) => match &state.records[log_id][since] {
528                RecordStatus::Validated(record) => record.index + 1,
529                _ => unreachable!(),
530            },
531            None => 0,
532        };
533
534        Ok(log
535            .entries
536            .iter()
537            .skip(start_log_idx)
538            .take_while(|entry| entry.registry_index < registry_log_length)
539            .map(|entry| PublishedProtoEnvelope {
540                envelope: entry.record_content.clone(),
541                registry_index: entry.registry_index,
542            })
543            .take(limit as usize)
544            .collect())
545    }
546
547    async fn get_package_records(
548        &self,
549        log_id: &LogId,
550        registry_log_length: RegistryLen,
551        since: Option<&RecordId>,
552        limit: u16,
553    ) -> Result<Vec<PublishedProtoEnvelope<package::PackageRecord>>, DataStoreError> {
554        let state = self.0.read().await;
555
556        let log = state
557            .packages
558            .get(log_id)
559            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
560
561        if !state.checkpoints.contains_key(&registry_log_length) {
562            return Err(DataStoreError::CheckpointNotFound(registry_log_length));
563        };
564
565        let start_log_idx = match since {
566            Some(since) => match &state.records[log_id][since] {
567                RecordStatus::Validated(record) => record.index + 1,
568                _ => unreachable!(),
569            },
570            None => 0,
571        };
572
573        Ok(log
574            .entries
575            .iter()
576            .skip(start_log_idx)
577            .take_while(|entry| entry.registry_index < registry_log_length)
578            .map(|entry| PublishedProtoEnvelope {
579                envelope: entry.record_content.clone(),
580                registry_index: entry.registry_index,
581            })
582            .take(limit as usize)
583            .collect())
584    }
585
586    async fn get_operator_record(
587        &self,
588        log_id: &LogId,
589        record_id: &RecordId,
590    ) -> Result<super::Record<operator::OperatorRecord>, DataStoreError> {
591        let state = self.0.read().await;
592        let status = state
593            .records
594            .get(log_id)
595            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
596            .get(record_id)
597            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
598
599        let (status, envelope, registry_index) = match status {
600            RecordStatus::Pending(PendingRecord::Operator { record, .. }) => {
601                (super::RecordStatus::Pending, record.clone().unwrap(), None)
602            }
603            RecordStatus::Rejected(RejectedRecord::Operator { record, reason }) => (
604                super::RecordStatus::Rejected(reason.into()),
605                record.clone(),
606                None,
607            ),
608            RecordStatus::Validated(r) => {
609                let log = state
610                    .operators
611                    .get(log_id)
612                    .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
613
614                let published_length = state
615                    .checkpoints
616                    .last()
617                    .map(|(_, c)| c.as_ref().checkpoint.log_length)
618                    .unwrap_or_default();
619
620                (
621                    if r.registry_index < published_length {
622                        super::RecordStatus::Published
623                    } else {
624                        super::RecordStatus::Validated
625                    },
626                    log.entries[r.index].record_content.clone(),
627                    Some(r.registry_index),
628                )
629            }
630            _ => return Err(DataStoreError::RecordNotFound(record_id.clone())),
631        };
632
633        Ok(super::Record {
634            status,
635            envelope,
636            registry_index,
637        })
638    }
639
640    async fn get_package_record(
641        &self,
642        log_id: &LogId,
643        record_id: &RecordId,
644    ) -> Result<super::Record<package::PackageRecord>, DataStoreError> {
645        let state = self.0.read().await;
646        let status = state
647            .records
648            .get(log_id)
649            .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
650            .get(record_id)
651            .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
652
653        let (status, envelope, registry_index) = match status {
654            RecordStatus::Pending(PendingRecord::Package { record, .. }) => {
655                (super::RecordStatus::Pending, record.clone().unwrap(), None)
656            }
657            RecordStatus::Rejected(RejectedRecord::Package { record, reason }) => (
658                super::RecordStatus::Rejected(reason.into()),
659                record.clone(),
660                None,
661            ),
662            RecordStatus::Validated(r) => {
663                let log = state
664                    .packages
665                    .get(log_id)
666                    .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
667
668                let published_length = state
669                    .checkpoints
670                    .last()
671                    .map(|(_, c)| c.as_ref().checkpoint.log_length)
672                    .unwrap_or_default();
673
674                (
675                    if r.registry_index < published_length {
676                        super::RecordStatus::Published
677                    } else {
678                        super::RecordStatus::Validated
679                    },
680                    log.entries[r.index].record_content.clone(),
681                    Some(r.registry_index),
682                )
683            }
684            _ => return Err(DataStoreError::RecordNotFound(record_id.clone())),
685        };
686
687        Ok(super::Record {
688            status,
689            envelope,
690            registry_index,
691        })
692    }
693
694    async fn verify_package_record_signature(
695        &self,
696        log_id: &LogId,
697        record: &ProtoEnvelope<package::PackageRecord>,
698    ) -> Result<(), DataStoreError> {
699        let state = self.0.read().await;
700        let key = match state
701            .packages
702            .get(log_id)
703            .and_then(|log| log.state.public_key(record.key_id()))
704        {
705            Some(key) => Some(key),
706            None => match record.as_ref().entries.first() {
707                Some(PackageEntry::Init { key, .. }) => Some(key),
708                _ => return Err(DataStoreError::UnknownKey(record.key_id().clone())),
709            },
710        }
711        .ok_or_else(|| DataStoreError::UnknownKey(record.key_id().clone()))?;
712
713        package::PackageRecord::verify(key, record.content_bytes(), record.signature())
714            .map_err(|_| DataStoreError::SignatureVerificationFailed(record.signature().clone()))
715    }
716
717    async fn verify_can_publish_package(
718        &self,
719        operator_log_id: &LogId,
720        package_name: &PackageName,
721    ) -> Result<(), DataStoreError> {
722        let state = self.0.read().await;
723
724        // verify namespace is defined and not imported
725        match state
726            .operators
727            .get(operator_log_id)
728            .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))?
729            .state
730            .namespace_state(package_name.namespace())
731        {
732            Some(state) => match state {
733                operator::NamespaceState::Defined => {}
734                operator::NamespaceState::Imported { .. } => {
735                    return Err(DataStoreError::PackageNamespaceImported(
736                        package_name.namespace().to_string(),
737                    ))
738                }
739            },
740            None => {
741                return Err(DataStoreError::PackageNamespaceNotDefined(
742                    package_name.namespace().to_string(),
743                ))
744            }
745        }
746
747        Ok(())
748    }
749
750    async fn verify_timestamped_checkpoint_signature(
751        &self,
752        operator_log_id: &LogId,
753        ts_checkpoint: &SerdeEnvelope<TimestampedCheckpoint>,
754    ) -> Result<(), DataStoreError> {
755        let state = self.0.read().await;
756
757        let state = &state
758            .operators
759            .get(operator_log_id)
760            .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))?
761            .state;
762
763        TimestampedCheckpoint::verify(
764            state
765                .public_key(ts_checkpoint.key_id())
766                .ok_or(DataStoreError::UnknownKey(ts_checkpoint.key_id().clone()))?,
767            &ts_checkpoint.as_ref().encode(),
768            ts_checkpoint.signature(),
769        )
770        .or(Err(DataStoreError::SignatureVerificationFailed(
771            ts_checkpoint.signature().clone(),
772        )))?;
773
774        if !state.key_has_permission_to_sign_checkpoints(ts_checkpoint.key_id()) {
775            return Err(DataStoreError::KeyUnauthorized(
776                ts_checkpoint.key_id().clone(),
777            ));
778        }
779
780        Ok(())
781    }
782
783    #[cfg(feature = "debug")]
784    async fn debug_list_package_names(&self) -> anyhow::Result<Vec<PackageName>> {
785        let state = self.0.read().await;
786        Ok(state
787            .package_names
788            .values()
789            .filter_map(|opt_package_name| opt_package_name.clone())
790            .collect())
791    }
792}