use serde::{
de::{self, SeqAccess, Visitor},
ser::SerializeTuple,
Deserialize, Serialize,
};
use std::{
fmt,
hash::{Hash, Hasher as StdHasher},
ops::Range,
};
use rs_merkle::{algorithms::Sha256, Hasher, MerkleProof};
#[derive(
Default, Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize,
)]
pub struct CommitHash(#[serde(with = "hex::serde")] pub [u8; 32]);
impl AsRef<[u8; 32]> for CommitHash {
fn as_ref(&self) -> &[u8; 32] {
&self.0
}
}
impl CommitHash {
pub fn to_bytes(&self) -> [u8; 32] {
self.0
}
}
impl From<CommitHash> for [u8; 32] {
fn from(value: CommitHash) -> Self {
value.0
}
}
impl fmt::Display for CommitHash {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", hex::encode(self.0))
}
}
#[derive(Debug, Eq, PartialEq)]
pub enum Comparison {
Equal,
Contains(Vec<usize>, Vec<[u8; 32]>),
Unknown,
}
pub struct CommitProof {
pub root: <Sha256 as Hasher>::Hash,
pub proof: MerkleProof<Sha256>,
pub length: usize,
pub indices: Range<usize>,
}
impl Hash for CommitProof {
fn hash<H: StdHasher>(&self, state: &mut H) {
self.root.hash(state);
self.proof.proof_hashes().hash(state);
self.length.hash(state);
self.indices.hash(state);
}
}
impl PartialEq for CommitProof {
fn eq(&self, other: &Self) -> bool {
self.root == other.root
&& self.proof.proof_hashes() == other.proof.proof_hashes()
&& self.length == other.length
&& self.indices == other.indices
}
}
impl Eq for CommitProof {}
impl Clone for CommitProof {
fn clone(&self) -> Self {
let hashes = self.proof.proof_hashes().to_vec();
CommitProof {
root: self.root,
proof: MerkleProof::<Sha256>::new(hashes),
length: self.length,
indices: self.indices.clone(),
}
}
}
impl CommitProof {
pub fn root(&self) -> &<Sha256 as Hasher>::Hash {
&self.root
}
pub fn root_hex(&self) -> String {
hex::encode(self.root)
}
pub fn len(&self) -> usize {
self.length
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl From<CommitProof> for (CommitHash, usize) {
fn from(value: CommitProof) -> Self {
(CommitHash(value.root), value.length)
}
}
impl Default for CommitProof {
fn default() -> Self {
Self {
root: [0; 32],
proof: MerkleProof::<Sha256>::new(vec![]),
length: 0,
indices: 0..0,
}
}
}
impl fmt::Debug for CommitProof {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CommitProof")
.field("root", &hex::encode(self.root))
.field("size", &self.length)
.field("leaves", &self.indices)
.finish()
}
}
impl serde::Serialize for CommitProof {
fn serialize<S>(
&self,
serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut tup = serializer.serialize_tuple(4)?;
let root_hash = hex::encode(self.root);
tup.serialize_element(&root_hash)?;
let hashes = self.proof.proof_hashes();
tup.serialize_element(hashes)?;
tup.serialize_element(&self.length)?;
tup.serialize_element(&self.indices)?;
tup.end()
}
}
struct CommitProofVisitor;
impl<'de> Visitor<'de> for CommitProofVisitor {
type Value = CommitProof;
fn expecting(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
Ok(())
}
fn visit_seq<A>(
self,
mut seq: A,
) -> std::result::Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let root_hash: String = seq.next_element()?.ok_or_else(|| {
de::Error::custom("expecting a root hash for commit proof")
})?;
let root_hash = hex::decode(root_hash).map_err(de::Error::custom)?;
let root_hash: [u8; 32] =
root_hash.as_slice().try_into().map_err(de::Error::custom)?;
let hashes: Vec<[u8; 32]> = seq.next_element()?.ok_or_else(|| {
de::Error::custom("expecting sequence of proof hashes")
})?;
let length: usize = seq.next_element()?.ok_or_else(|| {
de::Error::custom("expecting tree length usize")
})?;
let indices: Range<usize> = seq
.next_element()?
.ok_or_else(|| de::Error::custom("expecting leaf node range"))?;
Ok(CommitProof {
root: root_hash,
proof: MerkleProof::new(hashes),
length,
indices,
})
}
}
impl<'de> serde::Deserialize<'de> for CommitProof {
fn deserialize<D>(
deserializer: D,
) -> std::result::Result<CommitProof, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_tuple(4, CommitProofVisitor)
}
}
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct CommitPair {
pub local: CommitProof,
pub remote: CommitProof,
}
#[derive(Debug)]
pub enum CommitRelationship {
Equal(CommitPair),
Ahead(CommitPair, usize),
Behind(CommitPair, usize),
Diverged(CommitPair),
}
impl fmt::Display for CommitRelationship {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Equal(_) => {
write!(f, "up to date")
}
Self::Behind(_, diff) => {
write!(f, "{} change(s) behind remote: pull changes", diff)
}
Self::Ahead(_, diff) => {
write!(f, "{} change(s) ahead of remote: push changes", diff)
}
Self::Diverged(_) => {
write!(f, "local and remote have diverged: force push or force pull to synchronize trees")
}
}
}
}
impl CommitRelationship {
pub fn pair(&self) -> &CommitPair {
match self {
Self::Equal(pair) | Self::Diverged(pair) => pair,
Self::Behind(pair, _) | Self::Ahead(pair, _) => pair,
}
}
}