1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use jmt::KeyHash;
5#[cfg(all(target_os = "zkvm", feature = "bench"))]
6use sov_zk_cycle_macros::cycle_tracker;
7
8use crate::internal_cache::OrderedReadsAndWrites;
9use crate::storage::{Storage, StorageKey, StorageProof, StorageValue};
10use crate::witness::Witness;
11use crate::MerkleProofSpec;
12
13#[cfg(all(target_os = "zkvm", feature = "bench"))]
14extern crate risc0_zkvm;
15
16#[derive(Default)]
18pub struct ZkStorage<S: MerkleProofSpec> {
19 _phantom_hasher: PhantomData<S::Hasher>,
20}
21
22impl<S: MerkleProofSpec> Clone for ZkStorage<S> {
23 fn clone(&self) -> Self {
24 Self {
25 _phantom_hasher: Default::default(),
26 }
27 }
28}
29
30impl<S: MerkleProofSpec> ZkStorage<S> {
31 pub fn new() -> Self {
33 Self {
34 _phantom_hasher: Default::default(),
35 }
36 }
37}
38
39impl<S: MerkleProofSpec> Storage for ZkStorage<S> {
40 type Witness = S::Witness;
41 type RuntimeConfig = ();
42 type Proof = jmt::proof::SparseMerkleProof<S::Hasher>;
43 type StateUpdate = ();
44 type Root = jmt::RootHash;
45
46 fn with_config(_config: Self::RuntimeConfig) -> Result<Self, anyhow::Error> {
47 Ok(Self::new())
48 }
49
50 fn get(&self, _key: &StorageKey, witness: &Self::Witness) -> Option<StorageValue> {
51 witness.get_hint()
52 }
53
54 #[cfg_attr(all(target_os = "zkvm", feature = "bench"), cycle_tracker)]
55 fn compute_state_update(
56 &self,
57 state_accesses: OrderedReadsAndWrites,
58 witness: &Self::Witness,
59 ) -> Result<(Self::Root, Self::StateUpdate), anyhow::Error> {
60 let prev_state_root = witness.get_hint();
61
62 for (key, read_value) in state_accesses.ordered_reads {
64 let key_hash = KeyHash::with::<S::Hasher>(key.key.as_ref());
65 let proof: jmt::proof::SparseMerkleProof<S::Hasher> = witness.get_hint();
67 match read_value {
68 Some(val) => proof.verify_existence(
69 jmt::RootHash(prev_state_root),
70 key_hash,
71 val.value.as_ref(),
72 )?,
73 None => proof.verify_nonexistence(jmt::RootHash(prev_state_root), key_hash)?,
74 }
75 }
76
77 let batch = state_accesses
79 .ordered_writes
80 .into_iter()
81 .map(|(key, value)| {
82 let key_hash = KeyHash::with::<S::Hasher>(key.key.as_ref());
83 (
84 key_hash,
85 value.map(|v| Arc::try_unwrap(v.value).unwrap_or_else(|arc| (*arc).clone())),
86 )
87 })
88 .collect::<Vec<_>>();
89
90 let update_proof: jmt::proof::UpdateMerkleProof<S::Hasher> = witness.get_hint();
91 let new_root: [u8; 32] = witness.get_hint();
92 update_proof
93 .verify_update(
94 jmt::RootHash(prev_state_root),
95 jmt::RootHash(new_root),
96 batch,
97 )
98 .expect("Updates must be valid");
99
100 Ok((jmt::RootHash(new_root), ()))
101 }
102
103 #[cfg_attr(all(target_os = "zkvm", feature = "bench"), cycle_tracker)]
104 fn commit(&self, _node_batch: &Self::StateUpdate, _accessory_writes: &OrderedReadsAndWrites) {}
105
106 fn is_empty(&self) -> bool {
107 unimplemented!("Needs simplification in JellyfishMerkleTree: https://github.com/Sovereign-Labs/sovereign-sdk/issues/362")
108 }
109
110 fn open_proof(
111 state_root: Self::Root,
112 state_proof: StorageProof<Self::Proof>,
113 ) -> Result<(StorageKey, Option<StorageValue>), anyhow::Error> {
114 let StorageProof { key, value, proof } = state_proof;
115 let key_hash = KeyHash::with::<S::Hasher>(key.as_ref());
116
117 proof.verify(state_root, key_hash, value.as_ref().map(|v| v.value()))?;
118 Ok((key, value))
119 }
120}