1use super::{DataStore, DataStoreError};
2use futures::Stream;
3use indexmap::{IndexMap, IndexSet};
4use std::{pin::Pin, sync::Arc};
5use tokio::sync::RwLock;
6use warg_crypto::{hash::AnyHash, Encode, Signable};
7use warg_protocol::{
8 operator,
9 package::{self, PackageEntry},
10 registry::{
11 LogId, LogLeaf, PackageName, RecordId, RegistryIndex, RegistryLen, TimestampedCheckpoint,
12 },
13 ProtoEnvelope, PublishedProtoEnvelope, SerdeEnvelope,
14};
15
16struct Entry<R> {
17 registry_index: RegistryIndex,
18 record_content: ProtoEnvelope<R>,
19}
20
21struct Log<S, R> {
22 state: S,
23 entries: Vec<Entry<R>>,
24}
25
26impl<S, R> Default for Log<S, R>
27where
28 S: Default,
29{
30 fn default() -> Self {
31 Self {
32 state: S::default(),
33 entries: Vec::new(),
34 }
35 }
36}
37
38struct Record {
39 index: usize,
41 registry_index: RegistryIndex,
43}
44
45enum PendingRecord {
46 Operator {
47 record: Option<ProtoEnvelope<operator::OperatorRecord>>,
48 },
49 Package {
50 record: Option<ProtoEnvelope<package::PackageRecord>>,
51 missing: IndexSet<AnyHash>,
52 },
53}
54
55enum RejectedRecord {
56 Operator {
57 record: ProtoEnvelope<operator::OperatorRecord>,
58 reason: String,
59 },
60 Package {
61 record: ProtoEnvelope<package::PackageRecord>,
62 reason: String,
63 },
64}
65
66enum RecordStatus {
67 Pending(PendingRecord),
68 Rejected(RejectedRecord),
69 Validated(Record),
70}
71
72#[derive(Default)]
73struct State {
74 operators: IndexMap<LogId, Log<operator::LogState, operator::OperatorRecord>>,
75 packages: IndexMap<LogId, Log<package::LogState, package::PackageRecord>>,
76 package_names: IndexMap<LogId, Option<PackageName>>,
77 checkpoints: IndexMap<RegistryLen, SerdeEnvelope<TimestampedCheckpoint>>,
78 records: IndexMap<LogId, IndexMap<RecordId, RecordStatus>>,
79 log_leafs: IndexMap<RegistryIndex, LogLeaf>,
80}
81
82pub struct MemoryDataStore(Arc<RwLock<State>>);
89
90impl MemoryDataStore {
91 pub fn new() -> Self {
92 Self(Arc::new(RwLock::new(State::default())))
93 }
94}
95
96impl Default for MemoryDataStore {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102#[axum::async_trait]
103impl DataStore for MemoryDataStore {
104 async fn get_all_checkpoints(
105 &self,
106 ) -> Result<
107 Pin<Box<dyn Stream<Item = Result<TimestampedCheckpoint, DataStoreError>> + Send>>,
108 DataStoreError,
109 > {
110 Ok(Box::pin(futures::stream::empty()))
111 }
112
113 async fn get_all_validated_records(
114 &self,
115 ) -> Result<Pin<Box<dyn Stream<Item = Result<LogLeaf, DataStoreError>> + Send>>, DataStoreError>
116 {
117 Ok(Box::pin(futures::stream::empty()))
118 }
119
120 async fn get_log_leafs_starting_with_registry_index(
121 &self,
122 starting_index: RegistryIndex,
123 limit: usize,
124 ) -> Result<Vec<(RegistryIndex, LogLeaf)>, DataStoreError> {
125 let state = self.0.read().await;
126
127 let limit = if limit > state.log_leafs.len() - starting_index {
128 state.log_leafs.len() - starting_index
129 } else {
130 limit
131 };
132
133 let mut leafs = Vec::with_capacity(limit);
134 for entry in starting_index..starting_index + limit {
135 match state.log_leafs.get(&entry) {
136 Some(log_leaf) => leafs.push((entry, log_leaf.clone())),
137 None => break,
138 }
139 }
140
141 Ok(leafs)
142 }
143
144 async fn get_log_leafs_with_registry_index(
145 &self,
146 entries: &[RegistryIndex],
147 ) -> Result<Vec<LogLeaf>, DataStoreError> {
148 let state = self.0.read().await;
149
150 let mut leafs = Vec::with_capacity(entries.len());
151 for entry in entries {
152 match state.log_leafs.get(entry) {
153 Some(log_leaf) => leafs.push(log_leaf.clone()),
154 None => return Err(DataStoreError::LogLeafNotFound(*entry)),
155 }
156 }
157
158 Ok(leafs)
159 }
160
161 async fn get_package_names(
162 &self,
163 log_ids: &[LogId],
164 ) -> Result<IndexMap<LogId, Option<PackageName>>, DataStoreError> {
165 let state = self.0.read().await;
166
167 log_ids
168 .iter()
169 .map(|log_id| {
170 if let Some(opt_package_name) = state.package_names.get(log_id) {
171 Ok((log_id.clone(), opt_package_name.clone()))
172 } else {
173 Err(DataStoreError::LogNotFound(log_id.clone()))
174 }
175 })
176 .collect::<Result<IndexMap<LogId, Option<PackageName>>, _>>()
177 }
178
179 async fn store_operator_record(
180 &self,
181 log_id: &LogId,
182 record_id: &RecordId,
183 record: &ProtoEnvelope<operator::OperatorRecord>,
184 ) -> Result<(), DataStoreError> {
185 let mut state = self.0.write().await;
186 let prev = state.records.entry(log_id.clone()).or_default().insert(
187 record_id.clone(),
188 RecordStatus::Pending(PendingRecord::Operator {
189 record: Some(record.clone()),
190 }),
191 );
192
193 assert!(prev.is_none());
194 Ok(())
195 }
196
197 async fn reject_operator_record(
198 &self,
199 log_id: &LogId,
200 record_id: &RecordId,
201 reason: &str,
202 ) -> Result<(), DataStoreError> {
203 let mut state = self.0.write().await;
204
205 let status = state
206 .records
207 .get_mut(log_id)
208 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
209 .get_mut(record_id)
210 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
211
212 let record = match status {
213 RecordStatus::Pending(PendingRecord::Operator { record }) => record.take().unwrap(),
214 _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
215 };
216
217 *status = RecordStatus::Rejected(RejectedRecord::Operator {
218 record,
219 reason: reason.to_string(),
220 });
221
222 Ok(())
223 }
224
225 async fn commit_operator_record(
226 &self,
227 log_id: &LogId,
228 record_id: &RecordId,
229 registry_index: RegistryIndex,
230 ) -> Result<(), DataStoreError> {
231 let mut state = self.0.write().await;
232
233 let State {
234 operators,
235 records,
236 log_leafs,
237 ..
238 } = &mut *state;
239
240 let status = records
241 .get_mut(log_id)
242 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
243 .get_mut(record_id)
244 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
245
246 match status {
247 RecordStatus::Pending(PendingRecord::Operator { record }) => {
248 let record = record.take().unwrap();
249 let log = operators.entry(log_id.clone()).or_default();
250 match log
251 .state
252 .clone()
253 .validate(&record)
254 .map_err(DataStoreError::from)
255 {
256 Ok(s) => {
257 log.state = s;
258 let index = log.entries.len();
259 log.entries.push(Entry {
260 registry_index,
261 record_content: record,
262 });
263 *status = RecordStatus::Validated(Record {
264 index,
265 registry_index,
266 });
267 log_leafs.insert(
268 registry_index,
269 LogLeaf {
270 log_id: log_id.clone(),
271 record_id: record_id.clone(),
272 },
273 );
274 Ok(())
275 }
276 Err(e) => {
277 *status = RecordStatus::Rejected(RejectedRecord::Operator {
278 record,
279 reason: e.to_string(),
280 });
281 Err(e)
282 }
283 }
284 }
285 _ => Err(DataStoreError::RecordNotPending(record_id.clone())),
286 }
287 }
288
289 async fn store_package_record(
290 &self,
291 log_id: &LogId,
292 package_name: &PackageName,
293 record_id: &RecordId,
294 record: &ProtoEnvelope<package::PackageRecord>,
295 missing: &IndexSet<&AnyHash>,
296 ) -> Result<(), DataStoreError> {
297 debug_assert!({
299 use warg_protocol::Record;
300 let contents = record.as_ref().contents();
301 missing.is_subset(&contents)
302 });
303
304 let mut state = self.0.write().await;
305 let prev = state.records.entry(log_id.clone()).or_default().insert(
306 record_id.clone(),
307 RecordStatus::Pending(PendingRecord::Package {
308 record: Some(record.clone()),
309 missing: missing.iter().map(|&d| d.clone()).collect(),
310 }),
311 );
312 state
313 .package_names
314 .insert(log_id.clone(), Some(package_name.clone()));
315
316 assert!(prev.is_none());
317 Ok(())
318 }
319
320 async fn reject_package_record(
321 &self,
322 log_id: &LogId,
323 record_id: &RecordId,
324 reason: &str,
325 ) -> Result<(), DataStoreError> {
326 let mut state = self.0.write().await;
327
328 let status = state
329 .records
330 .get_mut(log_id)
331 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
332 .get_mut(record_id)
333 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
334
335 let record = match status {
336 RecordStatus::Pending(PendingRecord::Package { record, .. }) => record.take().unwrap(),
337 _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
338 };
339
340 *status = RecordStatus::Rejected(RejectedRecord::Package {
341 record,
342 reason: reason.to_string(),
343 });
344
345 Ok(())
346 }
347
348 async fn commit_package_record(
349 &self,
350 log_id: &LogId,
351 record_id: &RecordId,
352 registry_index: RegistryIndex,
353 ) -> Result<(), DataStoreError> {
354 let mut state = self.0.write().await;
355
356 let State {
357 packages,
358 records,
359 log_leafs,
360 ..
361 } = &mut *state;
362
363 let status = records
364 .get_mut(log_id)
365 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
366 .get_mut(record_id)
367 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
368
369 match status {
370 RecordStatus::Pending(PendingRecord::Package { record, .. }) => {
371 let record = record.take().unwrap();
372 let log = packages.entry(log_id.clone()).or_default();
373 match log
374 .state
375 .clone()
376 .validate(&record)
377 .map_err(DataStoreError::from)
378 {
379 Ok(state) => {
380 log.state = state;
381 let index = log.entries.len();
382 log.entries.push(Entry {
383 registry_index,
384 record_content: record,
385 });
386 *status = RecordStatus::Validated(Record {
387 index,
388 registry_index,
389 });
390 log_leafs.insert(
391 registry_index,
392 LogLeaf {
393 log_id: log_id.clone(),
394 record_id: record_id.clone(),
395 },
396 );
397 Ok(())
398 }
399 Err(e) => {
400 *status = RecordStatus::Rejected(RejectedRecord::Package {
401 record,
402 reason: e.to_string(),
403 });
404 Err(e)
405 }
406 }
407 }
408 _ => Err(DataStoreError::RecordNotPending(record_id.clone())),
409 }
410 }
411
412 async fn is_content_missing(
413 &self,
414 log_id: &LogId,
415 record_id: &RecordId,
416 digest: &AnyHash,
417 ) -> Result<bool, DataStoreError> {
418 let state = self.0.read().await;
419 let log = state
420 .records
421 .get(log_id)
422 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
423
424 let status = log
425 .get(record_id)
426 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
427
428 match status {
429 RecordStatus::Pending(PendingRecord::Operator { .. }) => {
430 Ok(false)
432 }
433 RecordStatus::Pending(PendingRecord::Package { missing, .. }) => {
434 Ok(missing.contains(digest))
435 }
436 _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
437 }
438 }
439
440 async fn set_content_present(
441 &self,
442 log_id: &LogId,
443 record_id: &RecordId,
444 digest: &AnyHash,
445 ) -> Result<bool, DataStoreError> {
446 let mut state = self.0.write().await;
447 let log = state
448 .records
449 .get_mut(log_id)
450 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
451
452 let status = log
453 .get_mut(record_id)
454 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
455
456 match status {
457 RecordStatus::Pending(PendingRecord::Operator { .. }) => {
458 Ok(false)
460 }
461 RecordStatus::Pending(PendingRecord::Package { missing, .. }) => {
462 if missing.is_empty() {
463 return Ok(false);
464 }
465
466 missing.swap_remove(digest);
468 Ok(missing.is_empty())
469 }
470 _ => return Err(DataStoreError::RecordNotPending(record_id.clone())),
471 }
472 }
473
474 async fn store_checkpoint(
475 &self,
476 _checkpoint_id: &AnyHash,
477 ts_checkpoint: SerdeEnvelope<TimestampedCheckpoint>,
478 ) -> Result<(), DataStoreError> {
479 let mut state = self.0.write().await;
480
481 state
482 .checkpoints
483 .insert(ts_checkpoint.as_ref().checkpoint.log_length, ts_checkpoint);
484
485 Ok(())
486 }
487
488 async fn get_latest_checkpoint(
489 &self,
490 ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, DataStoreError> {
491 let state = self.0.read().await;
492 let checkpoint = state.checkpoints.values().last().unwrap();
493 Ok(checkpoint.clone())
494 }
495
496 async fn get_checkpoint(
497 &self,
498 log_length: RegistryLen,
499 ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, DataStoreError> {
500 let state = self.0.read().await;
501 let checkpoint = state
502 .checkpoints
503 .get(&log_length)
504 .ok_or_else(|| DataStoreError::CheckpointNotFound(log_length))?;
505 Ok(checkpoint.clone())
506 }
507
508 async fn get_operator_records(
509 &self,
510 log_id: &LogId,
511 registry_log_length: RegistryLen,
512 since: Option<&RecordId>,
513 limit: u16,
514 ) -> Result<Vec<PublishedProtoEnvelope<operator::OperatorRecord>>, DataStoreError> {
515 let state = self.0.read().await;
516
517 let log = state
518 .operators
519 .get(log_id)
520 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
521
522 if !state.checkpoints.contains_key(®istry_log_length) {
523 return Err(DataStoreError::CheckpointNotFound(registry_log_length));
524 };
525
526 let start_log_idx = match since {
527 Some(since) => match &state.records[log_id][since] {
528 RecordStatus::Validated(record) => record.index + 1,
529 _ => unreachable!(),
530 },
531 None => 0,
532 };
533
534 Ok(log
535 .entries
536 .iter()
537 .skip(start_log_idx)
538 .take_while(|entry| entry.registry_index < registry_log_length)
539 .map(|entry| PublishedProtoEnvelope {
540 envelope: entry.record_content.clone(),
541 registry_index: entry.registry_index,
542 })
543 .take(limit as usize)
544 .collect())
545 }
546
547 async fn get_package_records(
548 &self,
549 log_id: &LogId,
550 registry_log_length: RegistryLen,
551 since: Option<&RecordId>,
552 limit: u16,
553 ) -> Result<Vec<PublishedProtoEnvelope<package::PackageRecord>>, DataStoreError> {
554 let state = self.0.read().await;
555
556 let log = state
557 .packages
558 .get(log_id)
559 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
560
561 if !state.checkpoints.contains_key(®istry_log_length) {
562 return Err(DataStoreError::CheckpointNotFound(registry_log_length));
563 };
564
565 let start_log_idx = match since {
566 Some(since) => match &state.records[log_id][since] {
567 RecordStatus::Validated(record) => record.index + 1,
568 _ => unreachable!(),
569 },
570 None => 0,
571 };
572
573 Ok(log
574 .entries
575 .iter()
576 .skip(start_log_idx)
577 .take_while(|entry| entry.registry_index < registry_log_length)
578 .map(|entry| PublishedProtoEnvelope {
579 envelope: entry.record_content.clone(),
580 registry_index: entry.registry_index,
581 })
582 .take(limit as usize)
583 .collect())
584 }
585
586 async fn get_operator_record(
587 &self,
588 log_id: &LogId,
589 record_id: &RecordId,
590 ) -> Result<super::Record<operator::OperatorRecord>, DataStoreError> {
591 let state = self.0.read().await;
592 let status = state
593 .records
594 .get(log_id)
595 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
596 .get(record_id)
597 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
598
599 let (status, envelope, registry_index) = match status {
600 RecordStatus::Pending(PendingRecord::Operator { record, .. }) => {
601 (super::RecordStatus::Pending, record.clone().unwrap(), None)
602 }
603 RecordStatus::Rejected(RejectedRecord::Operator { record, reason }) => (
604 super::RecordStatus::Rejected(reason.into()),
605 record.clone(),
606 None,
607 ),
608 RecordStatus::Validated(r) => {
609 let log = state
610 .operators
611 .get(log_id)
612 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
613
614 let published_length = state
615 .checkpoints
616 .last()
617 .map(|(_, c)| c.as_ref().checkpoint.log_length)
618 .unwrap_or_default();
619
620 (
621 if r.registry_index < published_length {
622 super::RecordStatus::Published
623 } else {
624 super::RecordStatus::Validated
625 },
626 log.entries[r.index].record_content.clone(),
627 Some(r.registry_index),
628 )
629 }
630 _ => return Err(DataStoreError::RecordNotFound(record_id.clone())),
631 };
632
633 Ok(super::Record {
634 status,
635 envelope,
636 registry_index,
637 })
638 }
639
640 async fn get_package_record(
641 &self,
642 log_id: &LogId,
643 record_id: &RecordId,
644 ) -> Result<super::Record<package::PackageRecord>, DataStoreError> {
645 let state = self.0.read().await;
646 let status = state
647 .records
648 .get(log_id)
649 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?
650 .get(record_id)
651 .ok_or_else(|| DataStoreError::RecordNotFound(record_id.clone()))?;
652
653 let (status, envelope, registry_index) = match status {
654 RecordStatus::Pending(PendingRecord::Package { record, .. }) => {
655 (super::RecordStatus::Pending, record.clone().unwrap(), None)
656 }
657 RecordStatus::Rejected(RejectedRecord::Package { record, reason }) => (
658 super::RecordStatus::Rejected(reason.into()),
659 record.clone(),
660 None,
661 ),
662 RecordStatus::Validated(r) => {
663 let log = state
664 .packages
665 .get(log_id)
666 .ok_or_else(|| DataStoreError::LogNotFound(log_id.clone()))?;
667
668 let published_length = state
669 .checkpoints
670 .last()
671 .map(|(_, c)| c.as_ref().checkpoint.log_length)
672 .unwrap_or_default();
673
674 (
675 if r.registry_index < published_length {
676 super::RecordStatus::Published
677 } else {
678 super::RecordStatus::Validated
679 },
680 log.entries[r.index].record_content.clone(),
681 Some(r.registry_index),
682 )
683 }
684 _ => return Err(DataStoreError::RecordNotFound(record_id.clone())),
685 };
686
687 Ok(super::Record {
688 status,
689 envelope,
690 registry_index,
691 })
692 }
693
694 async fn verify_package_record_signature(
695 &self,
696 log_id: &LogId,
697 record: &ProtoEnvelope<package::PackageRecord>,
698 ) -> Result<(), DataStoreError> {
699 let state = self.0.read().await;
700 let key = match state
701 .packages
702 .get(log_id)
703 .and_then(|log| log.state.public_key(record.key_id()))
704 {
705 Some(key) => Some(key),
706 None => match record.as_ref().entries.first() {
707 Some(PackageEntry::Init { key, .. }) => Some(key),
708 _ => return Err(DataStoreError::UnknownKey(record.key_id().clone())),
709 },
710 }
711 .ok_or_else(|| DataStoreError::UnknownKey(record.key_id().clone()))?;
712
713 package::PackageRecord::verify(key, record.content_bytes(), record.signature())
714 .map_err(|_| DataStoreError::SignatureVerificationFailed(record.signature().clone()))
715 }
716
717 async fn verify_can_publish_package(
718 &self,
719 operator_log_id: &LogId,
720 package_name: &PackageName,
721 ) -> Result<(), DataStoreError> {
722 let state = self.0.read().await;
723
724 match state
726 .operators
727 .get(operator_log_id)
728 .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))?
729 .state
730 .namespace_state(package_name.namespace())
731 {
732 Some(state) => match state {
733 operator::NamespaceState::Defined => {}
734 operator::NamespaceState::Imported { .. } => {
735 return Err(DataStoreError::PackageNamespaceImported(
736 package_name.namespace().to_string(),
737 ))
738 }
739 },
740 None => {
741 return Err(DataStoreError::PackageNamespaceNotDefined(
742 package_name.namespace().to_string(),
743 ))
744 }
745 }
746
747 Ok(())
748 }
749
750 async fn verify_timestamped_checkpoint_signature(
751 &self,
752 operator_log_id: &LogId,
753 ts_checkpoint: &SerdeEnvelope<TimestampedCheckpoint>,
754 ) -> Result<(), DataStoreError> {
755 let state = self.0.read().await;
756
757 let state = &state
758 .operators
759 .get(operator_log_id)
760 .ok_or_else(|| DataStoreError::LogNotFound(operator_log_id.clone()))?
761 .state;
762
763 TimestampedCheckpoint::verify(
764 state
765 .public_key(ts_checkpoint.key_id())
766 .ok_or(DataStoreError::UnknownKey(ts_checkpoint.key_id().clone()))?,
767 &ts_checkpoint.as_ref().encode(),
768 ts_checkpoint.signature(),
769 )
770 .or(Err(DataStoreError::SignatureVerificationFailed(
771 ts_checkpoint.signature().clone(),
772 )))?;
773
774 if !state.key_has_permission_to_sign_checkpoints(ts_checkpoint.key_id()) {
775 return Err(DataStoreError::KeyUnauthorized(
776 ts_checkpoint.key_id().clone(),
777 ));
778 }
779
780 Ok(())
781 }
782
/// Lists every package name recorded in the store, skipping log ids whose
/// name entry is `None`.
#[cfg(feature = "debug")]
async fn debug_list_package_names(&self) -> anyhow::Result<Vec<PackageName>> {
    let state = self.0.read().await;

    // Flattening `&Option<PackageName>` drops the `None` entries.
    let names = state.package_names.values().flatten().cloned().collect();
    Ok(names)
}
792}