wagyu_model/
derivation_path.rs

1use std::{
2    fmt,
3    fmt::{Debug, Display},
4    str::FromStr,
5};
6
7/// The interface for a generic derivation path.
8pub trait DerivationPath: Clone + Debug + Display + FromStr + Send + Sync + 'static + Eq + Sized {
9    /// Returns a child index vector given the derivation path.
10    fn to_vec(&self) -> Result<Vec<ChildIndex>, DerivationPathError>;
11
12    /// Returns a derivation path given the child index vector.
13    fn from_vec(path: &Vec<ChildIndex>) -> Result<Self, DerivationPathError>;
14}
15
/// Errors reported while parsing or validating derivation paths and child indices.
///
/// Each variant's user-facing message is defined by its `#[fail(display = ...)]` attribute.
#[derive(Debug, Fail, PartialEq, Eq)]
pub enum DerivationPathError {
    /// A BIP32 path was expected.
    #[fail(display = "expected BIP32 path")]
    ExpectedBIP32Path,

    /// A BIP44 path was expected.
    #[fail(display = "expected BIP44 path")]
    ExpectedBIP44Path,

    /// A BIP49 path was expected.
    #[fail(display = "expected BIP49 path")]
    ExpectedBIP49Path,

    /// A valid Ethereum derivation path was expected.
    #[fail(display = "expected valid Ethereum derivation path")]
    ExpectedValidEthereumDerivationPath,

    /// A ZIP32 path was expected.
    #[fail(display = "expected ZIP32 path")]
    ExpectedZIP32Path,

    /// A hardened path was expected.
    #[fail(display = "expected hardened path")]
    ExpectedHardenedPath,

    /// A normal (non-hardened) path was expected.
    #[fail(display = "expected normal path")]
    ExpectedNormalPath,

    /// The child number is out of range; the payload is the rejected value.
    ///
    /// Returned by `ChildIndex::normal` and `ChildIndex::hardened` when bit 31 of the
    /// requested index is set.
    #[fail(display = "invalid child number: {}", _0)]
    InvalidChildNumber(u32),

    /// The textual child number could not be parsed as a `u32`.
    ///
    /// Returned by `ChildIndex::from_str` when the numeric part fails to parse.
    #[fail(display = "invalid child number format")]
    InvalidChildNumberFormat,

    /// The derivation path is invalid; the payload is interpolated into the message.
    #[fail(display = "invalid derivation path: {}", _0)]
    InvalidDerivationPath(String),
}
48
/// Represents a child index for a derivation path.
///
/// The payload of either variant is the index *without* the hardened bit. Because the
/// comparison traits are derived and `Normal` is declared first, every `Normal` value
/// orders before every `Hardened` value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChildIndex {
    /// A non-hardened index: `Normal(n)` is written `n` in path notation.
    Normal(u32),
    /// A hardened index: `Hardened(n)` stands for the raw index `n + (1 << 31)` and is
    /// written `n'` in path notation.
    Hardened(u32),
}
57
58impl ChildIndex {
59    /// Returns [`Normal`] from an index, or errors if the index is not within [0, 2^31 - 1].
60    pub fn normal(index: u32) -> Result<Self, DerivationPathError> {
61        if index & (1 << 31) == 0 {
62            Ok(ChildIndex::Normal(index))
63        } else {
64            Err(DerivationPathError::InvalidChildNumber(index))
65        }
66    }
67
68    /// Returns [`Hardened`] from an index, or errors if the index is not within [0, 2^31 - 1].
69    pub fn hardened(index: u32) -> Result<Self, DerivationPathError> {
70        if index & (1 << 31) == 0 {
71            Ok(ChildIndex::Hardened(index))
72        } else {
73            Err(DerivationPathError::InvalidChildNumber(index))
74        }
75    }
76
77    /// Returns `true` if the child index is a [`Normal`] value.
78    pub fn is_normal(&self) -> bool {
79        !self.is_hardened()
80    }
81
82    /// Returns `true` if the child index is a [`Hardened`] value.
83    pub fn is_hardened(&self) -> bool {
84        match *self {
85            ChildIndex::Hardened(_) => true,
86            ChildIndex::Normal(_) => false,
87        }
88    }
89
90    /// Returns the child index.
91    pub fn to_index(&self) -> u32 {
92        match self {
93            &ChildIndex::Hardened(i) => i + (1 << 31),
94            &ChildIndex::Normal(i) => i,
95        }
96    }
97}
98
99impl From<u32> for ChildIndex {
100    fn from(number: u32) -> Self {
101        if number & (1 << 31) != 0 {
102            ChildIndex::Hardened(number ^ (1 << 31))
103        } else {
104            ChildIndex::Normal(number)
105        }
106    }
107}
108
109impl From<ChildIndex> for u32 {
110    fn from(index: ChildIndex) -> Self {
111        match index {
112            ChildIndex::Normal(number) => number,
113            ChildIndex::Hardened(number) => number | (1 << 31),
114        }
115    }
116}
117
118impl FromStr for ChildIndex {
119    type Err = DerivationPathError;
120
121    fn from_str(inp: &str) -> Result<Self, Self::Err> {
122        Ok(match inp.chars().last().map_or(false, |l| l == '\'' || l == 'h') {
123            true => Self::hardened(
124                inp[0..inp.len() - 1]
125                    .parse()
126                    .map_err(|_| DerivationPathError::InvalidChildNumberFormat)?,
127            )?,
128            false => Self::normal(inp.parse().map_err(|_| DerivationPathError::InvalidChildNumberFormat)?)?,
129        })
130    }
131}
132
133impl fmt::Display for ChildIndex {
134    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135        match *self {
136            ChildIndex::Hardened(number) => write!(f, "{}'", number),
137            ChildIndex::Normal(number) => write!(f, "{}", number),
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit tests for `ChildIndex`: constructors, predicates, raw-index conversions,
    /// parsing and formatting.
    mod child_index {
        use super::*;

        /// Smallest raw index whose most significant (hardened) bit is set.
        const THRESHOLD: u32 = 1 << 31;

        #[test]
        fn normal() {
            for i in 0..THRESHOLD {
                assert_eq!(ChildIndex::Normal(i), ChildIndex::normal(i).unwrap());
            }
            // Inclusive upper bound so that `u32::MAX` itself is checked.
            for i in THRESHOLD..=std::u32::MAX {
                assert_eq!(Err(DerivationPathError::InvalidChildNumber(i)), ChildIndex::normal(i));
            }
        }

        #[test]
        fn hardened() {
            for i in 0..THRESHOLD {
                assert_eq!(ChildIndex::Hardened(i), ChildIndex::hardened(i).unwrap());
            }
            // Inclusive upper bound so that `u32::MAX` itself is checked.
            for i in THRESHOLD..=std::u32::MAX {
                assert_eq!(Err(DerivationPathError::InvalidChildNumber(i)), ChildIndex::hardened(i));
            }
        }

        #[test]
        fn is_normal() {
            for i in 0..THRESHOLD {
                assert!(ChildIndex::Normal(i).is_normal());
                assert!(!ChildIndex::Hardened(i).is_normal());
            }
        }

        #[test]
        fn is_hardened() {
            for i in 0..THRESHOLD {
                assert!(!ChildIndex::Normal(i).is_hardened());
                assert!(ChildIndex::Hardened(i).is_hardened());
            }
        }

        #[test]
        fn to_index() {
            for i in 0..THRESHOLD {
                assert_eq!(i, ChildIndex::Normal(i).to_index());
                assert_eq!(i | THRESHOLD, ChildIndex::Hardened(i).to_index());
            }
        }

        #[test]
        fn from() {
            // Inclusive upper bound so that `u32::MAX` itself is checked.
            for i in 0..=std::u32::MAX {
                if i < THRESHOLD {
                    assert_eq!(ChildIndex::Normal(i), ChildIndex::from(i));
                } else {
                    assert_eq!(ChildIndex::Hardened(i ^ THRESHOLD), ChildIndex::from(i));
                }
            }
        }

        #[test]
        fn into_u32() {
            for i in (0..THRESHOLD).step_by(1 << 10) {
                assert_eq!(i, u32::from(ChildIndex::Normal(i)));
                assert_eq!(i | THRESHOLD, u32::from(ChildIndex::Hardened(i)));
            }
        }

        #[test]
        fn from_str() {
            for i in (0..THRESHOLD).step_by(1 << 10) {
                assert_eq!(ChildIndex::Normal(i), ChildIndex::from_str(&format!("{}", i)).unwrap());
                assert_eq!(
                    ChildIndex::Hardened(i),
                    ChildIndex::from_str(&format!("{}\'", i)).unwrap()
                );
                assert_eq!(
                    ChildIndex::Hardened(i),
                    ChildIndex::from_str(&format!("{}h", i)).unwrap()
                );
            }
        }

        #[test]
        fn from_str_invalid() {
            // Text that is not a decimal `u32` once the optional marker is removed.
            for &inp in ["", "'", "h", "abc", "-1", "1'h", "4294967296"].iter() {
                assert_eq!(
                    Err(DerivationPathError::InvalidChildNumberFormat),
                    ChildIndex::from_str(inp)
                );
            }
            // Well-formed numbers that have the hardened bit set are rejected by range.
            assert_eq!(
                Err(DerivationPathError::InvalidChildNumber(THRESHOLD)),
                ChildIndex::from_str("2147483648")
            );
            assert_eq!(
                Err(DerivationPathError::InvalidChildNumber(THRESHOLD)),
                ChildIndex::from_str("2147483648'")
            );
        }

        #[test]
        fn to_string() {
            for i in (0..THRESHOLD).step_by(1 << 10) {
                assert_eq!(format!("{}", i), ChildIndex::Normal(i).to_string());
                assert_eq!(format!("{}\'", i), ChildIndex::Hardened(i).to_string());
            }
        }
    }
}