sos_core/commit/
proof.rs

1//! Types that encapsulate commit proofs and comparisons.
2use super::TreeHash;
3use rs_merkle::{algorithms::Sha256, MerkleProof};
4use serde::{Deserialize, Serialize};
5use std::{
6    fmt,
7    hash::{Hash, Hasher as StdHasher},
8    str::FromStr,
9};
10
11/// Hash representation that provides a hexadecimal display.
12#[derive(
13    Default, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Hash,
14)]
15pub struct CommitHash(#[serde(with = "hex::serde")] pub TreeHash);
16
17impl fmt::Debug for CommitHash {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_tuple("CommitHash")
20            .field(&self.to_string())
21            .finish()
22    }
23}
24
25impl AsRef<TreeHash> for CommitHash {
26    fn as_ref(&self) -> &TreeHash {
27        &self.0
28    }
29}
30
31impl From<CommitHash> for [u8; 32] {
32    fn from(value: CommitHash) -> Self {
33        value.0
34    }
35}
36
37impl From<&CommitHash> for [u8; 32] {
38    fn from(value: &CommitHash) -> Self {
39        value.0
40    }
41}
42
43impl fmt::Display for CommitHash {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "{}", hex::encode(self.0))
46    }
47}
48
49impl FromStr for CommitHash {
50    type Err = crate::Error;
51
52    fn from_str(value: &str) -> Result<Self, Self::Err> {
53        let value = hex::decode(value)?;
54        let value: TreeHash = value.as_slice().try_into()?;
55        Ok(Self(value))
56    }
57}
58
59/// The result of comparing two commit trees.
60///
61/// Either the trees are equal, the other tree
62/// is a subset of this tree or the trees completely
63/// diverge.
64#[derive(Default, Debug, Clone, Eq, PartialEq)]
65pub enum Comparison {
66    /// Trees are equal as their root commits match.
67    Equal,
68    /// Tree contains the other proof and returns
69    /// the indices that matched.
70    Contains(Vec<usize>),
71    /// Unable to find a match against the proof.
72    #[default]
73    Unknown,
74}
75
76mod proof_serde {
77    use rs_merkle::{algorithms::Sha256, MerkleProof};
78    use serde::{
79        de::{Deserialize, Deserializer, Error},
80        Serializer,
81    };
82    use std::borrow::Cow;
83
84    pub fn serialize<S>(
85        proof: &MerkleProof<Sha256>,
86        serializer: S,
87    ) -> Result<S::Ok, S::Error>
88    where
89        S: Serializer,
90    {
91        serializer.serialize_str(&hex::encode(&proof.to_bytes()))
92    }
93
94    pub fn deserialize<'de, D>(
95        deserializer: D,
96    ) -> Result<MerkleProof<Sha256>, D::Error>
97    where
98        D: Deserializer<'de>,
99    {
100        <Cow<'de, str> as Deserialize<'de>>::deserialize(deserializer)
101            .and_then(|s| hex::decode(&*s).map_err(Error::custom))
102            .and_then(|b| {
103                MerkleProof::<Sha256>::from_bytes(&b).map_err(Error::custom)
104            })
105    }
106}
107
108/// Represents a root hash and a proof of certain nodes.
109#[derive(Serialize, Deserialize)]
110pub struct CommitProof {
111    /// Root hash.
112    pub root: CommitHash,
113    /// Merkle proof.
114    #[serde(with = "proof_serde")]
115    pub proof: MerkleProof<Sha256>,
116    /// Length of the tree.
117    pub length: usize,
118    /// Indices to prove.
119    pub indices: Vec<usize>,
120}
121
122impl Hash for CommitProof {
123    fn hash<H: StdHasher>(&self, state: &mut H) {
124        self.root.hash(state);
125        self.proof.proof_hashes().hash(state);
126        self.length.hash(state);
127        self.indices.hash(state);
128    }
129}
130
131impl PartialEq for CommitProof {
132    fn eq(&self, other: &Self) -> bool {
133        self.root == other.root
134            && self.proof.proof_hashes() == other.proof.proof_hashes()
135            && self.length == other.length
136            && self.indices == other.indices
137    }
138}
139
140impl Eq for CommitProof {}
141
142impl Clone for CommitProof {
143    fn clone(&self) -> Self {
144        let hashes = self.proof.proof_hashes().to_vec();
145        CommitProof {
146            root: self.root,
147            proof: MerkleProof::<Sha256>::new(hashes),
148            length: self.length,
149            indices: self.indices.clone(),
150        }
151    }
152}
153
154impl CommitProof {
155    /// Root hash for the proof.
156    pub fn root(&self) -> &CommitHash {
157        &self.root
158    }
159
160    /// Number of leaves in the commit tree.
161    pub fn len(&self) -> usize {
162        self.length
163    }
164
165    /// Determine if this proof is empty.
166    pub fn is_empty(&self) -> bool {
167        self.len() == 0
168    }
169
170    /// Verify the indices of this proof using
171    /// a slice of leaves.
172    pub fn verify_leaves(
173        &self,
174        leaves: &[TreeHash],
175    ) -> (bool, Vec<TreeHash>) {
176        let leaves_to_prove = self
177            .indices
178            .iter()
179            .filter_map(|i| leaves.get(*i))
180            .copied()
181            .collect::<Vec<_>>();
182        (
183            self.proof.verify(
184                self.root().into(),
185                &self.indices,
186                leaves_to_prove.as_slice(),
187                leaves.len(),
188            ),
189            leaves_to_prove,
190        )
191    }
192}
193
194impl From<CommitProof> for (CommitHash, usize) {
195    fn from(value: CommitProof) -> Self {
196        (value.root, value.length)
197    }
198}
199
200impl From<CommitProof> for CommitHash {
201    fn from(value: CommitProof) -> Self {
202        value.root
203    }
204}
205
206impl Default for CommitProof {
207    fn default() -> Self {
208        Self {
209            root: Default::default(),
210            proof: MerkleProof::<Sha256>::new(vec![]),
211            length: 0,
212            indices: vec![],
213        }
214    }
215}
216
217impl fmt::Debug for CommitProof {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        f.debug_struct("CommitProof")
220            .field("root", &self.root.to_string())
221            //.field("proofs", self.1.proof_hashes())
222            .field("length", &self.length)
223            .field("indices", &self.indices)
224            .finish()
225    }
226}