trezor_crypto/
hd_node.rs

1use crate::bip39::Mnemonic;
2use crate::curve::{Curve, CurveInfoLock, PrivateKey, PublicKey};
3use crate::ecdsa::canonical::{CanonicalFnLock, IsCanonicalFn};
4use crate::ed25519::{Ed25519, Ed25519PrivateKey, Ed25519PublicKey};
5use crate::hasher::{Digest, HashingAlgorithm};
6use crate::signature::{RecoverableSignature, Signature, SIG_LEN};
7use derivation_path::{ChildIndex, DerivationPath};
8use std::marker::PhantomData;
9use std::{fmt, ops};
10
/// Length in bytes of the private key stored in a `sys::HDNode`.
pub(crate) const HDNODE_PRIVKEY_LEN: usize = 32;
/// Length in bytes of the public key stored in a `sys::HDNode`: a one-byte prefix followed by
/// 32 key bytes (NOTE(review): the meaning of the prefix byte is curve-specific).
pub(crate) const HDNODE_PUBKEY_LEN: usize = 1 + 32;
/// Length in bytes of an HD-node chain code.
pub const CHAIN_CODE_LEN: usize = 32;
/// Length in bytes of the private key extension carried by an HD node.
pub const PRIV_KEY_EXT_LEN: usize = 32;
15
16struct HDNodeRef<'a, C: Curve> {
17    hd_node: &'a sys::HDNode,
18    lock: C::CurveInfoLock,
19}
20
21impl<'a, C: Curve> HDNodeRef<'a, C> {
22    #[inline]
23    fn curve_info(&self) -> &C::CurveInfoLock {
24        &self.lock
25    }
26    #[inline]
27    fn as_ptr(&self) -> *const sys::HDNode {
28        self.hd_node
29    }
30    /// WARNING: Should only be used for functions that fill the public key
31    #[inline]
32    unsafe fn as_mut_ptr(&self) -> *mut sys::HDNode {
33        self.as_ptr() as *mut sys::HDNode
34    }
35}
36
37impl<'a, C: Curve> ops::Deref for HDNodeRef<'a, C> {
38    type Target = sys::HDNode;
39    fn deref(&self) -> &Self::Target {
40        self.hd_node
41    }
42}
43
44struct HDNodeMutRef<'a, C: Curve> {
45    hd_node: &'a mut sys::HDNode,
46    _lock: C::CurveInfoLock,
47}
48
49impl<'a, C: Curve> HDNodeMutRef<'a, C> {
50    #[inline]
51    fn as_ptr(&mut self) -> *mut sys::HDNode {
52        self.hd_node
53    }
54}
55
56impl<'a, C: Curve> ops::Deref for HDNodeMutRef<'a, C> {
57    type Target = sys::HDNode;
58    fn deref(&self) -> &Self::Target {
59        self.hd_node
60    }
61}
62
63impl<'a, C: Curve> ops::DerefMut for HDNodeMutRef<'a, C> {
64    fn deref_mut(&mut self) -> &mut Self::Target {
65        self.hd_node
66    }
67}
68
69pub struct HDNode<C: Curve> {
70    hd_node: sys::HDNode,
71    curve: PhantomData<C>,
72}
73
74impl<C: Curve> HDNode<C> {
75    unsafe fn zeroed() -> Self {
76        let hd_node = std::mem::zeroed();
77        Self {
78            hd_node,
79            curve: PhantomData,
80        }
81    }
82    #[inline]
83    pub fn depth(&self) -> u8 {
84        self.hd_node.depth as u8
85    }
86
87    #[inline]
88    pub fn child_index(&self) -> ChildIndex {
89        ChildIndex::from_bits(self.hd_node.child_num)
90    }
91
92    #[inline]
93    pub fn chain_code(&self) -> [u8; CHAIN_CODE_LEN] {
94        self.hd_node.chain_code
95    }
96
97    #[inline]
98    pub fn public_key(&self) -> C::PublicKey {
99        C::PublicKey::from_bytes_unchecked(self.hd_node.public_key)
100    }
101
102    #[inline]
103    pub fn fingerprint(&self) -> u32 {
104        unsafe {
105            let hd_node = self.borrow();
106            sys::hdnode_fingerprint(hd_node.as_mut_ptr())
107        }
108    }
109
110    unsafe fn borrow(&self) -> HDNodeRef<C> {
111        HDNodeRef {
112            hd_node: &self.hd_node,
113            lock: C::curve_info_lock(),
114        }
115    }
116
117    unsafe fn borrow_mut(&mut self) -> HDNodeMutRef<C> {
118        HDNodeMutRef {
119            hd_node: &mut self.hd_node,
120            _lock: C::curve_info_lock(),
121        }
122    }
123}
124
125impl<C: Curve> Clone for HDNode<C> {
126    fn clone(&self) -> Self {
127        Self {
128            hd_node: self.hd_node.clone(),
129            curve: PhantomData,
130        }
131    }
132}
133
134pub struct ExtendedPrivateKey<C: Curve>(HDNode<C>);
135
136impl<C: Curve> ExtendedPrivateKey<C> {
137    pub fn new(
138        private_key: C::PrivateKey,
139        chain_code: [u8; CHAIN_CODE_LEN],
140        child_index: ChildIndex,
141        depth: u8,
142    ) -> Option<Self> {
143        let private_key_bytes = private_key.to_bytes();
144        let (inner, res) = unsafe {
145            let mut inner = HDNode::zeroed();
146            let mut hd_node = inner.borrow_mut();
147            let res = sys::hdnode_from_xprv(
148                depth as u32,
149                child_index.to_bits(),
150                chain_code.as_ptr(),
151                private_key_bytes.as_ptr(),
152                C::name_ptr(),
153                hd_node.as_ptr(),
154            );
155            (inner, res)
156        };
157        if res == 1 {
158            let mut this = Self(inner);
159            this.fill_public_key();
160            Some(this)
161        } else {
162            None
163        }
164    }
165
166    pub fn from_seed(seed: &[u8]) -> Option<Self> {
167        let mut this: Self;
168        let res = unsafe {
169            this = std::mem::zeroed();
170            let mut hd_node = this.borrow_mut();
171            if C::is_cardano() {
172                sys::hdnode_from_seed_cardano(seed.as_ptr(), seed.len() as i32, hd_node.as_ptr())
173            } else {
174                sys::hdnode_from_seed(
175                    seed.as_ptr(),
176                    seed.len() as i32,
177                    C::name_ptr(),
178                    hd_node.as_ptr(),
179                )
180            }
181        };
182        if res == 1 {
183            this.fill_public_key();
184            Some(this)
185        } else {
186            None
187        }
188    }
189
190    pub fn from_mnemonic(mnemonic: &Mnemonic, password: &str) -> Option<Self> {
191        let mut this: Self;
192        if C::is_cardano() {
193            let res = unsafe {
194                this = std::mem::zeroed();
195                let mut hd_node = this.borrow_mut();
196                let pass = password.as_bytes();
197                let entropy = mnemonic.entropy_cardano();
198                sys::hdnode_from_entropy_cardano_icarus(
199                    pass.as_ptr(),
200                    pass.len() as i32,
201                    entropy.as_ptr(),
202                    entropy.len() as i32,
203                    hd_node.as_ptr(),
204                )
205            };
206            if res == 1 {
207                this.fill_public_key();
208                Some(this)
209            } else {
210                None
211            }
212        } else {
213            Self::from_seed(&mnemonic.seed(password))
214        }
215    }
216
217    pub fn extend_public_key(&self) -> ExtendedPublicKey<C> {
218        let mut inner = self.0.clone();
219        inner.hd_node.private_key.fill(0);
220        ExtendedPublicKey(inner)
221    }
222
223    #[inline]
224    pub fn private_key_extension(&self) -> [u8; PRIV_KEY_EXT_LEN] {
225        self.hd_node.private_key_extension
226    }
227
228    fn fill_public_key(&mut self) {
229        unsafe {
230            let mut hd_node = self.borrow_mut();
231            sys::hdnode_fill_public_key(hd_node.as_ptr())
232        }
233    }
234
235    pub fn derive_next(&mut self, index: ChildIndex) {
236        unsafe {
237            let mut hd_node = self.borrow_mut();
238            if C::is_cardano() {
239                sys::hdnode_private_ckd_cardano(hd_node.as_ptr(), index.to_bits());
240            } else {
241                sys::hdnode_private_ckd(hd_node.as_ptr(), index.to_bits());
242            }
243        }
244        self.fill_public_key();
245    }
246
247    pub fn derive(&mut self, path: &DerivationPath) {
248        for index in path {
249            self.derive_next(*index);
250        }
251    }
252
253    #[inline]
254    pub fn private_key(&self) -> C::PrivateKey {
255        C::PrivateKey::from_bytes_unchecked(self.hd_node.private_key)
256    }
257
258    pub fn sign<H: HashingAlgorithm, D: AsRef<[u8]>>(
259        &self,
260        data: D,
261        is_canonical: Option<IsCanonicalFn>,
262    ) -> Option<RecoverableSignature<C>> {
263        let data = data.as_ref();
264        let hasher_type = H::hasher_type();
265        let mut sig = [0; SIG_LEN];
266        let mut by = 0;
267        let res = unsafe {
268            let hd_node = self.borrow();
269            let curve_lock = hd_node.curve_info().curve();
270            sys::hdnode_sign(
271                hd_node.as_mut_ptr(),
272                data.as_ptr(),
273                data.len() as u32,
274                hasher_type,
275                sig.as_mut_ptr(),
276                &mut by,
277                curve_lock.is_canonical_fn(is_canonical),
278            )
279        };
280        if res == 0 {
281            Some(RecoverableSignature::new(Signature::from_bytes(sig), by))
282        } else {
283            None
284        }
285    }
286    pub fn sign_digest(
287        &self,
288        digest: &Digest,
289        is_canonical: Option<IsCanonicalFn>,
290    ) -> Option<RecoverableSignature<C>> {
291        let mut sig = [0; SIG_LEN];
292        let mut by = 0;
293        let res = unsafe {
294            let hd_node = self.borrow();
295            let curve_lock = hd_node.curve_info().curve();
296            sys::hdnode_sign_digest(
297                hd_node.as_mut_ptr(),
298                digest.as_ref().as_ptr(),
299                sig.as_mut_ptr(),
300                &mut by,
301                curve_lock.is_canonical_fn(is_canonical),
302            )
303        };
304        if res == 0 {
305            Some(RecoverableSignature::new(Signature::from_bytes(sig), by))
306        } else {
307            None
308        }
309    }
310}
311
312impl<C> ExtendedPrivateKey<C>
313where
314    C: Curve<PublicKey = Ed25519PublicKey, PrivateKey = Ed25519PrivateKey>,
315{
316    #[inline]
317    pub fn private_key_ext(&self) -> Ed25519PrivateKey {
318        Ed25519PrivateKey::from_bytes_unchecked(self.hd_node.private_key_extension)
319    }
320    #[inline]
321    pub fn public_key_ext(&self) -> Ed25519PublicKey {
322        self.private_key().public_key_ext(&self.private_key_ext())
323    }
324    #[inline]
325    pub fn sign_ext<D: AsRef<[u8]>>(&self, data: D) -> Signature<Ed25519> {
326        self.private_key().sign_ext(&self.private_key_ext(), data)
327    }
328}
329
330impl<C: Curve> Clone for ExtendedPrivateKey<C> {
331    fn clone(&self) -> Self {
332        Self(self.0.clone())
333    }
334}
335
336impl<C: Curve> ops::Deref for ExtendedPrivateKey<C> {
337    type Target = HDNode<C>;
338    fn deref(&self) -> &Self::Target {
339        &self.0
340    }
341}
342
343impl<C: Curve> ops::DerefMut for ExtendedPrivateKey<C> {
344    fn deref_mut(&mut self) -> &mut Self::Target {
345        &mut self.0
346    }
347}
348
349pub struct ExtendedPublicKey<C: Curve>(HDNode<C>);
350
351impl<C: Curve> ExtendedPublicKey<C> {
352    pub fn new(
353        public_key: C::PublicKey,
354        chain_code: [u8; CHAIN_CODE_LEN],
355        child_index: ChildIndex,
356        depth: u8,
357    ) -> Option<Self> {
358        let public_key_bytes = public_key.to_bytes();
359        let (inner, res) = unsafe {
360            let mut inner = HDNode::zeroed();
361            let mut hd_node = inner.borrow_mut();
362            let res = sys::hdnode_from_xpub(
363                depth as u32,
364                child_index.to_bits(),
365                chain_code.as_ptr(),
366                public_key_bytes.as_ptr(),
367                C::name_ptr(),
368                hd_node.as_ptr(),
369            );
370            (inner, res)
371        };
372        if res == 1 {
373            Some(Self(inner))
374        } else {
375            None
376        }
377    }
378
379    pub fn derive_next(&mut self, index: ChildIndex) {
380        unsafe {
381            let mut hd_node = self.borrow_mut();
382            sys::hdnode_public_ckd(hd_node.as_ptr(), index.to_bits());
383        }
384    }
385
386    pub fn derive(&mut self, path: &DerivationPath) {
387        for index in path {
388            self.derive_next(*index);
389        }
390    }
391}
392
393impl<C: Curve> Clone for ExtendedPublicKey<C> {
394    fn clone(&self) -> Self {
395        Self(self.0.clone())
396    }
397}
398
399impl<C: Curve> ops::Deref for ExtendedPublicKey<C> {
400    type Target = HDNode<C>;
401    fn deref(&self) -> &Self::Target {
402        &self.0
403    }
404}
405
406impl<C: Curve> ops::DerefMut for ExtendedPublicKey<C> {
407    fn deref_mut(&mut self) -> &mut Self::Target {
408        &mut self.0
409    }
410}
411
412impl<C> fmt::Debug for HDNode<C>
413where
414    C: Curve,
415{
416    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417        f.debug_struct("HDNode")
418            .field("chain_code", &hex::encode(&self.chain_code()))
419            .field("child_index", &self.child_index())
420            .field("depth", &self.depth())
421            .field("fingerprint", &self.fingerprint())
422            .field("public_key", &hex::encode(&self.public_key().serialize()))
423            .finish()
424    }
425}
426
427impl<C> fmt::Debug for ExtendedPublicKey<C>
428where
429    C: Curve,
430{
431    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432        f.debug_tuple("ExtendedPublicKey").field(&self.0).finish()
433    }
434}
435
436impl<C> fmt::Debug for ExtendedPrivateKey<C>
437where
438    C: Curve,
439{
440    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441        f.debug_tuple("ExtendedPrivateKey").field(&self.0).finish()
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bip39::Mnemonic;
    use crate::ecdsa::Secp256k1;
    use crate::ed25519::Ed25519Cardano;
    use crate::hasher::Sha2;

    // Seed of BIP-32 test vector 1.
    const TEST_VECTOR_1_SEED: &str = "000102030405060708090a0b0c0d0e0f";
    // Seed of BIP-32 test vector 2.
    const TEST_VECTOR_2_SEED: &str = "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a29f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542";

    fn dignity_mnemonic() -> Mnemonic {
        Mnemonic::from_phrase("dignity pass list indicate nasty swamp pool script soccer toe leaf photo multiply desk host tomato cradle drill spread actor shine dismiss champion exotic").unwrap()
    }

    #[test]
    fn derive_next() {
        let seed = hex::decode(TEST_VECTOR_1_SEED).unwrap();
        let mut hd_node = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed).unwrap();
        assert_eq!(hd_node.depth(), 0);
        assert_eq!(hd_node.fingerprint(), 0x3442193e);
        assert_eq!(
            "e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35",
            hex::encode(hd_node.private_key().to_bytes())
        );
        assert_eq!(
            "873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508",
            hex::encode(hd_node.chain_code())
        );
        assert_eq!(
            "0339a36013301597daef41fbe593a02cc513d0b55527ec2df1050e2e8ff49c85c2",
            hex::encode(hd_node.public_key().serialize())
        );
        hd_node.derive_next(ChildIndex::Hardened(0));
        assert_eq!(hd_node.depth(), 1);
        assert_eq!(hd_node.fingerprint(), 0x5c1bd648);
        assert_eq!(
            "edb2e14f9ee77d26dd93b4ecede8d16ed408ce149b6cd80b0715a2d911a0afea",
            hex::encode(hd_node.private_key().to_bytes())
        );
        assert_eq!(
            "47fdacbd0f1097043b78c63c20c34ef4ed9a111d980047ad16282c7ae6236141",
            hex::encode(hd_node.chain_code())
        );
        assert_eq!(
            "035a784662a4a20a65bf6aab9ae98a6c068a81c52e4b032c0fb5400c706cfccc56",
            hex::encode(hd_node.public_key().serialize())
        );
    }

    #[test]
    fn derive() {
        let seed = hex::decode(TEST_VECTOR_2_SEED).unwrap();
        let mut hd_node = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed).unwrap();
        assert_eq!(hd_node.depth(), 0);
        assert_eq!(hd_node.fingerprint(), 0xbd16bee5);
        assert_eq!(
            "4b03d6fc340455b363f51020ad3ecca4f0850280cf436c70c727923f6db46c3e",
            hex::encode(hd_node.private_key().to_bytes())
        );
        assert_eq!(
            "60499f801b896d83179a4374aeb7822aaeaceaa0db1f85ee3e904c4defbd9689",
            hex::encode(hd_node.chain_code())
        );
        assert_eq!(
            "03cbcaa9c98c877a26977d00825c956a238e8dddfbd322cce4f74b0b5bd6ace4a7",
            hex::encode(hd_node.public_key().serialize())
        );
        hd_node.derive(&"m/0/2147483647'/1/2147483646'/2".parse().unwrap());
        assert_eq!(hd_node.depth(), 5);
        assert_eq!(hd_node.fingerprint(), 0x26132fdb);
        assert_eq!(
            "bb7d39bdb83ecf58f2fd82b6d918341cbef428661ef01ab97c28a4842125ac23",
            hex::encode(hd_node.private_key().to_bytes())
        );
        assert_eq!(
            "9452b549be8cea3ecb7a84bec10dcfd94afe4d129ebfd3b3cb58eedf394ed271",
            hex::encode(hd_node.chain_code())
        );
        assert_eq!(
            "024d902e1a2fc7a8755ab5b694c575fce742c48d9ff192e63df5193e4c7afe1f9c",
            hex::encode(hd_node.public_key().serialize())
        );
    }

    #[test]
    fn from_xprv() {
        let seed = hex::decode(TEST_VECTOR_2_SEED).unwrap();
        let hd_node = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed).unwrap();
        let hd_node2 = ExtendedPrivateKey::<Secp256k1>::new(
            hd_node.private_key(),
            hd_node.chain_code(),
            hd_node.child_index(),
            hd_node.depth(),
        )
        .unwrap();
        assert_eq!(
            hd_node.private_key().to_bytes(),
            hd_node2.private_key().to_bytes()
        );
        assert_eq!(hd_node.public_key(), hd_node2.public_key());
        assert_eq!(hd_node.chain_code(), hd_node2.chain_code());
        assert_eq!(hd_node.depth(), hd_node2.depth());
        assert_eq!(
            hd_node.private_key_extension(),
            hd_node2.private_key_extension()
        );
        assert_eq!(hd_node.child_index(), hd_node2.child_index());
    }

    #[test]
    fn from_xpub() {
        let seed = hex::decode(TEST_VECTOR_2_SEED).unwrap();
        let hd_node = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed)
            .unwrap()
            .extend_public_key();
        let hd_node2 = ExtendedPublicKey::<Secp256k1>::new(
            hd_node.public_key(),
            hd_node.chain_code(),
            hd_node.child_index(),
            hd_node.depth(),
        )
        .unwrap();
        assert_eq!(hd_node.public_key(), hd_node2.public_key());
        assert_eq!(hd_node.chain_code(), hd_node2.chain_code());
        assert_eq!(hd_node.depth(), hd_node2.depth());
        assert_eq!(hd_node.child_index(), hd_node2.child_index());
    }

    // Non-hardened public derivation must land on the same node as private derivation
    // followed by `extend_public_key` (the core BIP-32 xpub property).
    #[test]
    fn public_derivation_matches_private() {
        let seed = hex::decode(TEST_VECTOR_1_SEED).unwrap();
        let root = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed).unwrap();
        let path: DerivationPath = "m/0/1/2".parse().unwrap();

        let mut xprv = root.clone();
        xprv.derive(&path);
        let expected = xprv.extend_public_key();

        let mut xpub = root.extend_public_key();
        xpub.derive(&path);

        assert_eq!(xpub.depth(), 3);
        assert_eq!(xpub.public_key(), expected.public_key());
        assert_eq!(xpub.chain_code(), expected.chain_code());
        assert_eq!(xpub.child_index(), expected.child_index());
        assert_eq!(xpub.fingerprint(), expected.fingerprint());
    }

    // A hardened child cannot be derived from public data; the node must stay as it was.
    #[test]
    fn public_derivation_of_hardened_child_leaves_node_unchanged() {
        let seed = hex::decode(TEST_VECTOR_1_SEED).unwrap();
        let mut xpub = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed)
            .unwrap()
            .extend_public_key();
        let before = xpub.clone();
        xpub.derive_next(ChildIndex::Hardened(0));
        assert_eq!(xpub.depth(), before.depth());
        assert_eq!(xpub.public_key(), before.public_key());
        assert_eq!(xpub.chain_code(), before.chain_code());
    }

    #[test]
    fn sign() {
        let seed = hex::decode(TEST_VECTOR_2_SEED).unwrap();
        let hd_node = ExtendedPrivateKey::<Secp256k1>::from_seed(&seed).unwrap();
        let message = b"hello";
        let signature = hd_node.sign::<Sha2, _>(message, None).unwrap();
        let signature2 = hd_node
            .private_key()
            .sign::<Sha2, _>(message, None)
            .unwrap();
        assert_eq!(signature, signature2);
    }

    #[test]
    fn dignity_secp256k1() {
        let mut hd_node =
            ExtendedPrivateKey::<Secp256k1>::from_mnemonic(&dignity_mnemonic(), "").unwrap();
        hd_node.derive(&"m/1852'/1815'/0'".parse().unwrap());
        assert_eq!(hd_node.depth(), 3);
        assert_eq!(
            &hd_node.chain_code(),
            hex::decode("0d27e09175aed7737fabe8dc833d034f32750e01179dfcf26c74bafada708d38")
                .unwrap()
                .as_slice()
        );
        assert_eq!(
            &hd_node.public_key().serialize(),
            hex::decode("02ee2fdb748f9bc8648372cc79cbe543eed0401528a8ab91966f5135e55aac2d99")
                .unwrap()
                .as_slice()
        );
    }

    #[test]
    fn dignity_cardano() {
        let mut hd_node =
            ExtendedPrivateKey::<Ed25519Cardano>::from_mnemonic(&dignity_mnemonic(), "").unwrap();
        assert_eq!(hd_node.depth(), 0);
        assert_eq!(
            &hd_node.chain_code(),
            hex::decode("350df93ad0ebdbdd42d719badfec2670efe013902bdc05c838774d7118fb9ac8")
                .unwrap()
                .as_slice()
        );
        assert_eq!(
            &hd_node.public_key().serialize(),
            hex::decode("41fe8c524e2e1b4c67ea2b0030f121515085ffa4861e2816dbaaeaf93428eb63")
                .unwrap()
                .as_slice()
        );
        hd_node.derive(&"m/1852'/1815'/0'".parse().unwrap());
        assert_eq!(hd_node.depth(), 3);
        assert_eq!(
            &hd_node.chain_code(),
            hex::decode("415b7d92ecc8539cac4fcc23f2f243a0cfc59125129b9fb297e05bcc8625a51e")
                .unwrap()
                .as_slice()
        );
        assert_eq!(
            &hd_node.public_key().serialize(),
            hex::decode("80609213e0e94b2e49b03996fd57262fed51f34108d6167a69df6938a3435cb3")
                .unwrap()
                .as_slice()
        );
    }
}