1use {
13 core::{iter::IntoIterator, slice::Iter},
14 derivation_path::{ChildIndex, DerivationPath as DerivationPathInner},
15 std::{
16 convert::{Infallible, TryFrom},
17 fmt,
18 str::FromStr,
19 },
20 thiserror::Error,
21 uriparse::URIReference,
22};
23
/// Position of the account component in a path: it follows the two base indexes
/// (purpose and coin) that every BIP44 path built in this file starts with.
const ACCOUNT_INDEX: usize = 2;
/// Position of the change component, directly after the account component.
const CHANGE_INDEX: usize = 3;
26
/// Errors reported while parsing or building a [`DerivationPath`].
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum DerivationPathError {
    /// The path, key string or URI query could not be interpreted; the payload is a
    /// human-readable description of what was wrong.
    #[error("invalid derivation path: {0}")]
    InvalidDerivationPath(String),
    /// Target of the `From<Infallible>` conversion; lets infallible conversions be used
    /// where this error type is expected.
    #[error("infallible")]
    Infallible,
}
35
36impl From<Infallible> for DerivationPathError {
37 fn from(_: Infallible) -> Self {
38 Self::Infallible
39 }
40}
41
/// Newtype around the `derivation_path` crate's path type (imported here as
/// `DerivationPathInner`); the constructors, parsers and accessors in this file
/// operate on this wrapper.
#[derive(Clone, PartialEq, Eq)]
pub struct DerivationPath(DerivationPathInner);
44
45impl Default for DerivationPath {
46 fn default() -> Self {
47 Self::new_bip44(None, None)
48 }
49}
50
51impl TryFrom<&str> for DerivationPath {
52 type Error = DerivationPathError;
53 fn try_from(s: &str) -> Result<Self, Self::Error> {
54 Self::from_key_str(s)
55 }
56}
57
58impl AsRef<[ChildIndex]> for DerivationPath {
59 fn as_ref(&self) -> &[ChildIndex] {
60 self.0.as_ref()
61 }
62}
63
64impl DerivationPath {
65 fn new<P: Into<Box<[ChildIndex]>>>(path: P) -> Self {
66 Self(DerivationPathInner::new(path))
67 }
68
69 pub fn from_key_str(path: &str) -> Result<Self, DerivationPathError> {
70 Self::from_key_str_with_coin(path, Solana)
71 }
72
73 fn from_key_str_with_coin<T: Bip44>(path: &str, coin: T) -> Result<Self, DerivationPathError> {
74 let master_path = if path == "m" {
75 path.to_string()
76 } else {
77 format!("m/{path}")
78 };
79 let extend = DerivationPathInner::from_str(&master_path)
80 .map_err(|err| DerivationPathError::InvalidDerivationPath(err.to_string()))?;
81 let mut extend = extend.into_iter();
82 let account = extend.next().map(|index| index.to_u32());
83 let change = extend.next().map(|index| index.to_u32());
84 if extend.next().is_some() {
85 return Err(DerivationPathError::InvalidDerivationPath(format!(
86 "key path `{path}` too deep, only <account>/<change> supported"
87 )));
88 }
89 Ok(Self::new_bip44_with_coin(coin, account, change))
90 }
91
92 pub fn from_absolute_path_str(path: &str) -> Result<Self, DerivationPathError> {
93 let inner = DerivationPath::_from_absolute_path_insecure_str(path)?
94 .into_iter()
95 .map(|c| ChildIndex::Hardened(c.to_u32()))
96 .collect::<Vec<_>>();
97 Ok(Self(DerivationPathInner::new(inner)))
98 }
99
100 fn _from_absolute_path_insecure_str(path: &str) -> Result<Self, DerivationPathError> {
101 Ok(Self(DerivationPathInner::from_str(path).map_err(
102 |err| DerivationPathError::InvalidDerivationPath(err.to_string()),
103 )?))
104 }
105
106 pub fn new_bip44(account: Option<u32>, change: Option<u32>) -> Self {
107 Self::new_bip44_with_coin(Solana, account, change)
108 }
109
110 fn new_bip44_with_coin<T: Bip44>(coin: T, account: Option<u32>, change: Option<u32>) -> Self {
111 let mut indexes = coin.base_indexes();
112 if let Some(account) = account {
113 indexes.push(ChildIndex::Hardened(account));
114 if let Some(change) = change {
115 indexes.push(ChildIndex::Hardened(change));
116 }
117 }
118 Self::new(indexes)
119 }
120
121 pub fn account(&self) -> Option<&ChildIndex> {
122 self.0.path().get(ACCOUNT_INDEX)
123 }
124
125 pub fn change(&self) -> Option<&ChildIndex> {
126 self.0.path().get(CHANGE_INDEX)
127 }
128
129 pub fn path(&self) -> &[ChildIndex] {
130 self.0.path()
131 }
132
133 pub fn get_query(&self) -> String {
135 if let Some(account) = &self.account() {
136 if let Some(change) = &self.change() {
137 format!("?key={account}/{change}")
138 } else {
139 format!("?key={account}")
140 }
141 } else {
142 "".to_string()
143 }
144 }
145
146 pub fn from_uri_key_query(uri: &URIReference<'_>) -> Result<Option<Self>, DerivationPathError> {
147 Self::from_uri(uri, true)
148 }
149
150 pub fn from_uri_any_query(uri: &URIReference<'_>) -> Result<Option<Self>, DerivationPathError> {
151 Self::from_uri(uri, false)
152 }
153
154 fn from_uri(
155 uri: &URIReference<'_>,
156 key_only: bool,
157 ) -> Result<Option<Self>, DerivationPathError> {
158 if let Some(query) = uri.query() {
159 let query_str = query.as_str();
160 if query_str.is_empty() {
161 return Ok(None);
162 }
163 let query = qstring::QString::from(query_str);
164 if query.len() > 1 {
165 return Err(DerivationPathError::InvalidDerivationPath(
166 "invalid query string, extra fields not supported".to_string(),
167 ));
168 }
169 let key = query.get(QueryKey::Key.as_ref());
170 if let Some(key) = key {
171 return Self::from_key_str(key).map(Some);
174 }
175 if key_only {
176 return Err(DerivationPathError::InvalidDerivationPath(format!(
177 "invalid query string `{query_str}`, only `key` supported",
178 )));
179 }
180 let full_path = query.get(QueryKey::FullPath.as_ref());
181 if let Some(full_path) = full_path {
182 return Self::from_absolute_path_str(full_path).map(Some);
183 }
184 Err(DerivationPathError::InvalidDerivationPath(format!(
185 "invalid query string `{query_str}`, only `key` and `full-path` supported",
186 )))
187 } else {
188 Ok(None)
189 }
190 }
191}
192
193impl fmt::Debug for DerivationPath {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "m")?;
196 for index in self.0.path() {
197 write!(f, "/{index}")?;
198 }
199 Ok(())
200 }
201}
202
203impl<'a> IntoIterator for &'a DerivationPath {
204 type IntoIter = Iter<'a, ChildIndex>;
205 type Item = &'a ChildIndex;
206 fn into_iter(self) -> Self::IntoIter {
207 self.0.into_iter()
208 }
209}
210
/// Query field name whose value is parsed as a complete (absolute) derivation path.
const QUERY_KEY_FULL_PATH: &str = "full-path";
/// Query field name whose value is parsed as a key string.
const QUERY_KEY_KEY: &str = "key";
213
/// Error for a string that is not a recognised query key; the payload is the
/// offending input as it was supplied.
#[derive(Clone, Debug, Error, PartialEq, Eq)]
#[error("invalid query key `{0}`")]
struct QueryKeyError(String);
217
/// The query field names understood when extracting a derivation path from a URI.
enum QueryKey {
    /// The `full-path` field.
    FullPath,
    /// The `key` field.
    Key,
}
222
223impl FromStr for QueryKey {
224 type Err = QueryKeyError;
225 fn from_str(s: &str) -> Result<Self, Self::Err> {
226 let lowercase = s.to_ascii_lowercase();
227 match lowercase.as_str() {
228 QUERY_KEY_FULL_PATH => Ok(Self::FullPath),
229 QUERY_KEY_KEY => Ok(Self::Key),
230 _ => Err(QueryKeyError(s.to_string())),
231 }
232 }
233}
234
235impl AsRef<str> for QueryKey {
236 fn as_ref(&self) -> &str {
237 match self {
238 Self::FullPath => QUERY_KEY_FULL_PATH,
239 Self::Key => QUERY_KEY_KEY,
240 }
241 }
242}
243
244impl std::fmt::Display for QueryKey {
245 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
246 let s: &str = self.as_ref();
247 write!(f, "{s}")
248 }
249}
250
251trait Bip44 {
252 const PURPOSE: u32 = 44;
253 const COIN: u32;
254
255 fn base_indexes(&self) -> Vec<ChildIndex> {
256 vec![
257 ChildIndex::Hardened(Self::PURPOSE),
258 ChildIndex::Hardened(Self::COIN),
259 ]
260 }
261}
262
/// Marker type selecting Solana's coin component when building BIP44 paths.
struct Solana;
264
impl Bip44 for Solana {
    /// Coin component used for Solana paths (`m/44'/501'/...`).
    const COIN: u32 = 501;
}
268
#[cfg(test)]
mod tests {
    use {super::*, uriparse::URIReferenceBuilder};

    /// Coin with a coin component different from Solana's, so tests can tell them apart.
    struct TestCoin;
    impl Bip44 for TestCoin {
        const COIN: u32 = 999;
    }

    /// Builds a URI with scheme `test`, authority `path` and an empty path, attaching
    /// `query` as the query component when one is given.
    fn uri_with_query(query: Option<&str>) -> URIReference<'_> {
        let mut builder = URIReferenceBuilder::new();
        builder
            .try_scheme(Some("test"))
            .unwrap()
            .try_authority(Some("path"))
            .unwrap()
            .try_path("")
            .unwrap();
        if let Some(query) = query {
            builder.try_query(Some(query)).unwrap();
        }
        builder.build().unwrap()
    }

    /// Asserts that extracting a path from a URI carrying `query` is rejected.
    fn assert_uri_rejected(query: &str, key_only: bool) {
        let uri = uri_with_query(Some(query));
        assert!(matches!(
            DerivationPath::from_uri(&uri, key_only),
            Err(DerivationPathError::InvalidDerivationPath(_))
        ));
    }

    #[test]
    fn test_from_key_str() {
        // Account and change, written unhardened, hardened, and with escaped quotes.
        for &key in ["1/2", "1'/2'", "1\'/2\'"].iter() {
            assert_eq!(
                DerivationPath::from_key_str_with_coin(key, TestCoin).unwrap(),
                DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2))
            );
        }

        // Account only, in the same three spellings.
        for &key in ["1", "1'", "1\'"].iter() {
            assert_eq!(
                DerivationPath::from_key_str_with_coin(key, TestCoin).unwrap(),
                DerivationPath::new_bip44_with_coin(TestCoin, Some(1), None)
            );
        }

        // Too deep, or not numeric at all.
        for &key in ["1/2/3", "other", "1o"].iter() {
            assert!(DerivationPath::from_key_str_with_coin(key, TestCoin).is_err());
        }
    }

    #[test]
    fn test_from_absolute_path_str() {
        for &path in ["m/44/501", "m/44'/501'"].iter() {
            assert_eq!(
                DerivationPath::from_absolute_path_str(path).unwrap(),
                DerivationPath::default()
            );
        }

        for &path in ["m/44'/501'/1/2", "m/44'/501'/1'/2'"].iter() {
            assert_eq!(
                DerivationPath::from_absolute_path_str(path).unwrap(),
                DerivationPath::new_bip44(Some(1), Some(2))
            );
        }

        // A coin component other than Solana's is accepted.
        for &path in ["m/44'/999'/1/2", "m/44'/999'/1'/2'"].iter() {
            assert_eq!(
                DerivationPath::from_absolute_path_str(path).unwrap(),
                DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2))
            );
        }

        // So is a path that does not start with the BIP44 purpose.
        for &path in ["m/501'/0'/0/0", "m/501'/0'/0'/0'"].iter() {
            assert_eq!(
                DerivationPath::from_absolute_path_str(path).unwrap(),
                DerivationPath::new(vec![
                    ChildIndex::Hardened(501),
                    ChildIndex::Hardened(0),
                    ChildIndex::Hardened(0),
                    ChildIndex::Hardened(0),
                ])
            );
        }
    }

    #[test]
    fn test_from_uri() {
        let derivation_path = DerivationPath::new_bip44(Some(0), Some(0));

        // `key=<account>/<change>` in each accepted spelling.
        for &query in ["key=0/0", "key=0'/0'", "key=0\'/0\'"].iter() {
            let uri = uri_with_query(Some(query));
            assert_eq!(
                DerivationPath::from_uri(&uri, true).unwrap(),
                Some(derivation_path.clone())
            );
        }

        // `key=m` selects the bare prefix.
        let uri = uri_with_query(Some("key=m"));
        assert_eq!(
            DerivationPath::from_uri(&uri, true).unwrap(),
            Some(DerivationPath::new_bip44(None, None))
        );

        // No query at all, or an empty one, yields no path.
        for &query in [None, Some("")].iter() {
            let uri = uri_with_query(query);
            assert_eq!(DerivationPath::from_uri(&uri, true).unwrap(), None);
        }

        let rejected = [
            "key=0/0/0",
            "key=0/0&bad-key=0/0",
            "bad-key=0/0",
            "key=bad-value",
            "key=",
            "key",
        ];
        for &query in rejected.iter() {
            assert_uri_rejected(query, true);
        }
    }

    #[test]
    fn test_from_uri_full_path() {
        let derivation_path = DerivationPath::from_absolute_path_str("m/44'/999'/1'").unwrap();

        let accepted = [
            "full-path=m/44/999/1",
            "full-path=m/44'/999'/1'",
            "full-path=m/44\'/999\'/1\'",
        ];
        for &query in accepted.iter() {
            let uri = uri_with_query(Some(query));
            assert_eq!(
                DerivationPath::from_uri(&uri, false).unwrap(),
                Some(derivation_path.clone())
            );
        }

        // `full-path=m` yields the root path itself.
        let uri = uri_with_query(Some("full-path=m"));
        assert_eq!(
            DerivationPath::from_uri(&uri, false).unwrap(),
            Some(DerivationPath(DerivationPathInner::from_str("m").unwrap()))
        );

        // `full-path` is refused when only `key` is allowed.
        assert_uri_rejected("full-path=m/44/999/1", true);

        let rejected = [
            "key=0/0&full-path=m/44/999/1",
            "full-path=m/44/999/1&bad-key=0/0",
            "full-path=bad-value",
            "full-path=",
            "full-path",
        ];
        for &query in rejected.iter() {
            assert_uri_rejected(query, false);
        }
    }

    #[test]
    fn test_get_query() {
        let cases: [(Option<u32>, Option<u32>, &str); 3] = [
            (None, None, ""),
            (Some(1), None, "?key=1'"),
            (Some(1), Some(2), "?key=1'/2'"),
        ];
        for &(account, change, expected) in cases.iter() {
            let derivation_path = DerivationPath::new_bip44_with_coin(TestCoin, account, change);
            assert_eq!(derivation_path.get_query(), expected.to_string());
        }
    }

    #[test]
    fn test_derivation_path_debug() {
        let cases = [
            (DerivationPath::default(), "m/44'/501'"),
            (DerivationPath::new_bip44(Some(1), None), "m/44'/501'/1'"),
            (DerivationPath::new_bip44(Some(1), Some(2)), "m/44'/501'/1'/2'"),
        ];
        for (path, expected) in cases.iter() {
            assert_eq!(format!("{path:?}"), expected.to_string());
        }
    }
}