1use tor_basic_utils::{n_key_list, n_key_set};
5use tor_llcrypto::pk::ed25519::Ed25519Identity;
6use tor_llcrypto::pk::rsa::RsaIdentity;
7
8use crate::{HasRelayIds, RelayIdRef};
9
n_key_list! {
    #[derive(Clone, Debug)]
    /// A list of values that have relay identities, indexed by RSA and by
    /// Ed25519 identity.
    ///
    /// Unlike [`ByRelayIds`], several values in this list may share the same
    /// identity, so lookups by identity yield an iterator of matches rather than
    /// a single value.
    ///
    /// Both index keys are optional: a value is indexed only under the
    /// identities it actually reports.
    pub struct[H:HasRelayIds] ListByRelayIds[H] for H
    {
        // Indexed by `HasRelayIds::rsa_identity()`, when present.
        (Option) rsa: RsaIdentity { rsa_identity() },
        // Indexed by `HasRelayIds::ed_identity()`, when present.
        (Option) ed25519: Ed25519Identity { ed_identity() },
    }
}
29
n_key_set! {
    #[derive(Clone, Debug)]
    /// A set of values that have relay identities, indexed by RSA and by
    /// Ed25519 identity.
    ///
    /// Lookups by a single identity return at most one value; use
    /// [`ListByRelayIds`] when several values may share an identity.
    ///
    /// Both index keys are optional: a value is indexed only under the
    /// identities it actually reports.
    pub struct[H:HasRelayIds] ByRelayIds[H] for H
    {
        // Indexed by `HasRelayIds::rsa_identity()`, when present.
        (Option) rsa: RsaIdentity { rsa_identity() },
        // Indexed by `HasRelayIds::ed_identity()`, when present.
        (Option) ed25519: Ed25519Identity { ed_identity() },
    }
}
50
51impl<H: HasRelayIds> ByRelayIds<H> {
52 pub fn by_id<'a, T>(&self, key: T) -> Option<&H>
54 where
55 T: Into<RelayIdRef<'a>>,
56 {
57 match key.into() {
58 RelayIdRef::Ed25519(ed) => self.by_ed25519(ed),
59 RelayIdRef::Rsa(rsa) => self.by_rsa(rsa),
60 }
61 }
62
63 pub fn remove_by_id<'a, T>(&mut self, key: T) -> Option<H>
65 where
66 T: Into<RelayIdRef<'a>>,
67 {
68 match key.into() {
69 RelayIdRef::Ed25519(ed) => self.remove_by_ed25519(ed),
70 RelayIdRef::Rsa(rsa) => self.remove_by_rsa(rsa),
71 }
72 }
73
74 pub fn modify_by_id<'a, T, F>(&mut self, key: T, func: F) -> Vec<H>
78 where
79 T: Into<RelayIdRef<'a>>,
80 F: FnOnce(&mut H),
81 {
82 match key.into() {
83 RelayIdRef::Ed25519(ed) => self.modify_by_ed25519(ed, func),
84 RelayIdRef::Rsa(rsa) => self.modify_by_rsa(rsa, func),
85 }
86 }
87
88 pub fn by_all_ids<T>(&self, key: &T) -> Option<&H>
93 where
94 T: HasRelayIds,
95 {
96 let any_id = key.identities().next()?;
97 self.by_id(any_id)
98 .filter(|val| val.has_all_relay_ids_from(key))
99 }
100
101 pub fn modify_by_all_ids<T, F>(&mut self, key: &T, func: F) -> Vec<H>
106 where
107 T: HasRelayIds,
108 F: FnOnce(&mut H),
109 {
110 let any_id = match key.identities().next() {
111 Some(id) => id,
112 None => return Vec::new(),
113 };
114 self.modify_by_id(any_id, |val| {
115 if val.has_all_relay_ids_from(key) {
116 func(val);
117 }
118 })
119 }
120
121 pub fn remove_exact<T>(&mut self, key: &T) -> Option<H>
124 where
125 T: HasRelayIds,
126 {
127 let any_id = key.identities().next()?;
128 if self
129 .by_id(any_id)
130 .filter(|ent| ent.same_relay_ids(key))
131 .is_some()
132 {
133 self.remove_by_id(any_id)
134 } else {
135 None
136 }
137 }
138
139 pub fn remove_by_all_ids<T>(&mut self, key: &T) -> Option<H>
143 where
144 T: HasRelayIds,
145 {
146 let any_id = key.identities().next()?;
147 if self
148 .by_id(any_id)
149 .filter(|ent| ent.has_all_relay_ids_from(key))
150 .is_some()
151 {
152 self.remove_by_id(any_id)
153 } else {
154 None
155 }
156 }
157
158 pub fn all_overlapping<T>(&self, key: &T) -> Vec<&H>
163 where
164 T: HasRelayIds,
165 {
166 use by_address::ByAddress;
167 use std::collections::HashSet;
168
169 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
170
171 for ident in key.identities() {
172 if let Some(found) = self.by_id(ident) {
173 items.insert(ByAddress(found));
174 }
175 }
176
177 items.into_iter().map(|by_addr| by_addr.0).collect()
178 }
179}
180
181impl<H: HasRelayIds> ListByRelayIds<H> {
182 pub fn by_id<'a, T>(&self, key: T) -> ListByRelayIdsIter<H>
184 where
185 T: Into<RelayIdRef<'a>>,
186 {
187 match key.into() {
188 RelayIdRef::Ed25519(ed) => self.by_ed25519(ed),
189 RelayIdRef::Rsa(rsa) => self.by_rsa(rsa),
190 }
191 }
192
193 pub fn by_all_ids<'a>(&'a self, key: &'a impl HasRelayIds) -> impl Iterator<Item = &'a H> + 'a {
197 key.identities()
198 .next()
199 .map_or_else(Default::default, |id| self.by_id(id))
200 .filter(|val| val.has_all_relay_ids_from(key))
201 }
202
203 pub fn all_overlapping<T>(&self, key: &T) -> Vec<&H>
208 where
209 T: HasRelayIds,
210 {
211 use by_address::ByAddress;
212 use std::collections::HashSet;
213
214 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
215
216 for ident in key.identities() {
217 for found in self.by_id(ident) {
218 items.insert(ByAddress(found));
219 }
220 }
221
222 items.into_iter().map(|by_addr| by_addr.0).collect()
223 }
224
225 pub fn all_subset<T>(&self, key: &T) -> Vec<&H>
231 where
232 T: HasRelayIds,
233 {
234 use by_address::ByAddress;
235 use std::collections::HashSet;
236
237 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
238
239 for ident in key.identities() {
240 for found in self.by_id(ident) {
241 if key.has_all_relay_ids_from(found) {
243 items.insert(ByAddress(found));
244 }
245 }
246 }
247
248 items.into_iter().map(|by_addr| by_addr.0).collect()
249 }
250
251 pub fn remove_by_id<'a, T>(&mut self, key: T, filter: impl FnMut(&H) -> bool) -> Vec<H>
253 where
254 T: Into<RelayIdRef<'a>>,
255 {
256 match key.into() {
257 RelayIdRef::Ed25519(ed) => self.remove_by_ed25519(ed, filter),
258 RelayIdRef::Rsa(rsa) => self.remove_by_rsa(rsa, filter),
259 }
260 }
261
262 pub fn remove_exact<T>(&mut self, key: &T) -> Vec<H>
265 where
266 T: HasRelayIds,
267 {
268 let Some(id) = key.identities().next() else {
269 return Vec::new();
270 };
271
272 self.remove_by_id(id, |val| val.same_relay_ids(key))
273 }
274
275 pub fn remove_by_all_ids<T>(&mut self, key: &T) -> Vec<H>
279 where
280 T: HasRelayIds,
281 {
282 let Some(id) = key.identities().next() else {
283 return Vec::new();
284 };
285
286 self.remove_by_id(id, |val| val.has_all_relay_ids_from(key))
287 }
288}
289
290pub use tor_basic_utils::n_key_list::Error as ListByRelayIdsError;
291pub use tor_basic_utils::n_key_set::Error as ByRelayIdsError;
292
#[cfg(test)]
mod test {
    // Lints that are relaxed for test code in this module.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)]
    use super::*;
    use crate::{RelayIds, RelayIdsBuilder};

    /// Collect an iterator into a sorted `Vec`, so that results whose order is
    /// unspecified can be compared against a fixed expected list.
    fn sort<T: std::cmp::Ord>(i: impl Iterator<Item = T>) -> Vec<T> {
        let mut v: Vec<_> = i.collect();
        v.sort();
        v
    }

    /// Single-identity, all-identity and overlapping lookups on both the set
    /// and the list, using two relays that share no identities.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn lookup() {
        // rsa3 and ed3 are never inserted; lookups on them must miss.
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let rsa3: RsaIdentity = (*b"abcefghijklmnopQRSTU").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();
        let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();
        let ed3: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyz1234567").into();

        let keys1 = RelayIdsBuilder::default()
            .rsa_identity(rsa1)
            .ed_identity(ed1)
            .build()
            .unwrap();

        let keys2 = RelayIdsBuilder::default()
            .rsa_identity(rsa2)
            .ed_identity(ed2)
            .build()
            .unwrap();

        let mut set = ByRelayIds::new();
        set.insert(keys1.clone());
        set.insert(keys2.clone());

        let mut list = ListByRelayIds::new();
        list.insert(keys1.clone());
        list.insert(keys2.clone());

        // Lookup by a single identity of either kind.
        assert_eq!(set.by_id(&rsa1), Some(&keys1));
        assert_eq!(set.by_id(&ed1), Some(&keys1));
        assert_eq!(set.by_id(&rsa2), Some(&keys2));
        assert_eq!(set.by_id(&ed2), Some(&keys2));
        assert_eq!(set.by_id(&rsa3), None);
        assert_eq!(set.by_id(&ed3), None);
        assert_eq!(sort(list.by_id(&rsa1)), [&keys1]);
        assert_eq!(sort(list.by_id(&ed1)), [&keys1]);
        assert_eq!(sort(list.by_id(&rsa2)), [&keys2]);
        assert_eq!(sort(list.by_id(&ed2)), [&keys2]);
        assert_eq!(list.by_id(&rsa3).len(), 0);
        assert_eq!(list.by_id(&ed3).len(), 0);

        // Lookup requiring every identity in the key; an empty key matches nothing.
        assert_eq!(set.by_all_ids(&keys1), Some(&keys1));
        assert_eq!(set.by_all_ids(&keys2), Some(&keys2));
        assert_eq!(set.by_all_ids(&RelayIds::empty()), None);
        assert_eq!(sort(list.by_all_ids(&keys1)), [&keys1]);
        assert_eq!(sort(list.by_all_ids(&keys2)), [&keys2]);
        assert!(sort(list.by_all_ids(&RelayIds::empty())).is_empty());
        {
            // A key with fewer identities than the stored value still matches.
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa1)
                .build()
                .unwrap();
            assert_eq!(set.by_all_ids(&search), Some(&keys1));
            assert_eq!(sort(list.by_all_ids(&search)), [&keys1]);
        }
        {
            // A key mixing identities from two different values matches neither.
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa1)
                .ed_identity(ed2)
                .build()
                .unwrap();
            assert_eq!(set.by_all_ids(&search), None);
            assert!(sort(list.by_all_ids(&search)).is_empty());
        }

        // Overlap: any value sharing at least one identity with the key.
        assert_eq!(set.all_overlapping(&keys1), vec![&keys1]);
        assert_eq!(set.all_overlapping(&keys2), vec![&keys2]);
        assert_eq!(list.all_overlapping(&keys1), vec![&keys1]);
        assert_eq!(list.all_overlapping(&keys2), vec![&keys2]);
        {
            // A mixed key overlaps both values; result order is unspecified.
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa1)
                .ed_identity(ed2)
                .build()
                .unwrap();
            let answer = set.all_overlapping(&search);
            assert_eq!(answer.len(), 2);
            assert!(answer.contains(&&keys1));
            assert!(answer.contains(&&keys2));
            let answer = list.all_overlapping(&search);
            assert_eq!(answer.len(), 2);
            assert!(answer.contains(&&keys1));
            assert!(answer.contains(&&keys2));
        }
        {
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa2)
                .build()
                .unwrap();
            assert_eq!(set.all_overlapping(&search), vec![&keys2]);
            assert_eq!(list.all_overlapping(&search), vec![&keys2]);
        }
        {
            // An identity that is not present overlaps nothing.
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa3)
                .build()
                .unwrap();
            assert!(set.all_overlapping(&search).is_empty());
            assert!(list.all_overlapping(&search).is_empty());
        }
    }

    /// `remove_exact` removes only exact identity matches, while
    /// `remove_by_all_ids` also removes values with extra identities. Each step
    /// depends on the removals before it.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn remove_exact() {
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();
        let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();

        let keys1 = RelayIdsBuilder::default()
            .rsa_identity(rsa1)
            .ed_identity(ed1)
            .build()
            .unwrap();

        let keys2 = RelayIdsBuilder::default()
            .rsa_identity(rsa2)
            .ed_identity(ed2)
            .build()
            .unwrap();

        let mut set = ByRelayIds::new();
        set.insert(keys1.clone());
        set.insert(keys2.clone());
        assert_eq!(set.len(), 2);

        let mut list = ListByRelayIds::new();
        list.insert(keys1.clone());
        list.insert(keys2.clone());
        assert_eq!(list.len(), 2);

        // An exact key removes its value; only keys2 remains afterwards.
        assert_eq!(set.remove_exact(&keys1), Some(keys1.clone()));
        assert_eq!(set.len(), 1);
        assert_eq!(list.remove_exact(&keys1), vec![keys1.clone()]);
        assert_eq!(list.len(), 1);

        {
            // ed2 alone is not an exact match for keys2 (which also has rsa2).
            let search = RelayIdsBuilder::default().ed_identity(ed2).build().unwrap();

            assert_eq!(set.remove_exact(&search), None);
            assert_eq!(set.len(), 1);
            assert_eq!(list.remove_exact(&search), vec![]);
            assert_eq!(list.len(), 1);

            // keys2 lacks rsa1, so it does not have *all* of these identities.
            let no_match = RelayIdsBuilder::default()
                .ed_identity(ed2)
                .rsa_identity(rsa1)
                .build()
                .unwrap();
            assert_eq!(set.remove_by_all_ids(&no_match), None);
            assert_eq!(set.len(), 1);
            assert_eq!(list.remove_by_all_ids(&no_match), vec![]);
            assert_eq!(list.len(), 1);

            // keys2 has all of `search`'s identities (plus more), so it is removed.
            assert_eq!(set.remove_by_all_ids(&search), Some(keys2.clone()));
            assert!(set.is_empty());
            assert_eq!(list.remove_by_all_ids(&search), vec![keys2.clone()]);
            assert!(list.is_empty());
        }
    }

    /// `all_subset` returns values whose identities are all contained in the key.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn all_subset() {
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();

        // keys1 has both identity kinds; keys2 has only an RSA identity.
        let keys1 = RelayIdsBuilder::default()
            .rsa_identity(rsa1)
            .ed_identity(ed1)
            .build()
            .unwrap();

        let keys2 = RelayIdsBuilder::default()
            .rsa_identity(rsa2)
            .build()
            .unwrap();

        let mut list = ListByRelayIds::new();
        list.insert(keys1.clone());
        list.insert(keys2.clone());

        assert_eq!(list.all_subset(&keys1), vec![&keys1]);
        assert_eq!(list.all_subset(&keys2), vec![&keys2]);

        {
            // keys1 also has ed1, so it is not a subset of {rsa1}.
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa1)
                .build()
                .unwrap();
            assert!(list.all_subset(&search).is_empty());
        }

        {
            // keys1 also has rsa1, so it is not a subset of {ed1}.
            let search = RelayIdsBuilder::default().ed_identity(ed1).build().unwrap();
            assert!(list.all_subset(&search).is_empty());
        }

        {
            let search = RelayIdsBuilder::default()
                .rsa_identity(rsa2)
                .build()
                .unwrap();
            assert_eq!(list.all_subset(&search), vec![&keys2]);
        }

        {
            // keys1 is reached via ed1 but rejected (rsa1 missing from the key);
            // keys2 is a subset via rsa2.
            let search = RelayIdsBuilder::default()
                .ed_identity(ed1)
                .rsa_identity(rsa2)
                .build()
                .unwrap();
            assert_eq!(list.all_subset(&search), vec![&keys2]);
        }
    }

    /// Exercise `ListByRelayIds` with several values that share identities,
    /// including one whose identities come from two different relays.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn list_by_relay_ids() {
        /// A stand-in for a channel: an arbitrary payload `val` plus the
        /// identities it reports.
        #[derive(Clone, Debug)]
        struct ErsatzChannel<T> {
            val: T,
            ids: RelayIds,
        }

        impl<T> ErsatzChannel<T> {
            fn new(val: T, ids: RelayIds) -> Self {
                Self { val, ids }
            }
        }

        impl<T> HasRelayIds for ErsatzChannel<T> {
            fn identity(&self, key_type: crate::RelayIdType) -> Option<RelayIdRef<'_>> {
                self.ids.identity(key_type)
            }
        }

        /// Build a `RelayIds` from optional RSA and Ed25519 identities
        /// (`unwrap`s the builder result).
        fn ids(
            rsa: impl Into<Option<RsaIdentity>>,
            ed: impl Into<Option<Ed25519Identity>>,
        ) -> RelayIds {
            let mut ids = RelayIdsBuilder::default();
            if let Some(rsa) = rsa.into() {
                ids.rsa_identity(rsa);
            }
            if let Some(ed) = ed.into() {
                ids.ed_identity(ed);
            }
            ids.build().unwrap()
        }

        let rsa_a: RsaIdentity = (*b"12345678901234567890").into();
        let ed_a: Ed25519Identity = (*b"12345678901234567890123456789012").into();

        let ed_b: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();
        let rsa_b: RsaIdentity = (*b"abcefghijklmnopqrstu").into();

        let channel_a_all = ErsatzChannel::new("channel-a-all", ids(rsa_a, ed_a));

        // Two distinct values with identical identities: allowed in a list.
        let channel_a_rsa_only_1 = ErsatzChannel::new("channel-a-rsa-only-1", ids(rsa_a, None));

        let channel_a_rsa_only_2 = ErsatzChannel::new("channel-a-rsa-only-2", ids(rsa_a, None));

        let channel_a_ed_only = ErsatzChannel::new("channel-a-ed-only", ids(None, ed_a));

        let channel_b_all = ErsatzChannel::new("channel-b-all", ids(rsa_b, ed_b));

        // Mixes relay A's RSA identity with relay B's Ed25519 identity.
        let channel_invalid = ErsatzChannel::new("channel-invalid", ids(rsa_a, ed_b));

        let mut list = ListByRelayIds::new();
        list.insert(channel_a_all.clone());
        list.insert(channel_a_rsa_only_1.clone());
        list.insert(channel_a_rsa_only_2.clone());
        list.insert(channel_a_ed_only.clone());
        list.insert(channel_b_all.clone());
        list.insert(channel_invalid.clone());

        // Single-identity lookups return every value reporting that identity.
        assert_eq!(
            sort(list.by_id(&rsa_a).map(|x| x.val)),
            [
                "channel-a-all",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
                "channel-invalid",
            ],
        );

        assert_eq!(
            sort(list.by_id(&ed_a).map(|x| x.val)),
            ["channel-a-all", "channel-a-ed-only"],
        );

        assert_eq!(sort(list.by_id(&rsa_b).map(|x| x.val)), ["channel-b-all"]);

        assert_eq!(
            sort(list.by_id(&ed_b).map(|x| x.val)),
            ["channel-b-all", "channel-invalid"],
        );

        // Only values with both identities match a two-identity key.
        assert_eq!(
            sort(list.by_all_ids(&ids(rsa_a, ed_a)).map(|x| x.val)),
            ["channel-a-all"],
        );

        assert_eq!(
            sort(list.by_all_ids(&ids(rsa_b, ed_b)).map(|x| x.val)),
            ["channel-b-all"],
        );

        // Every value sharing rsa_a or ed_a, each listed once.
        assert_eq!(
            sort(
                list.all_overlapping(&ids(rsa_a, ed_a))
                    .into_iter()
                    .map(|x| x.val)
            ),
            [
                "channel-a-all",
                "channel-a-ed-only",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
                "channel-invalid",
            ],
        );

        // channel-invalid is excluded: its ed_b is not in the key.
        assert_eq!(
            sort(
                list.all_subset(&ids(rsa_a, ed_a))
                    .into_iter()
                    .map(|x| x.val)
            ),
            [
                "channel-a-all",
                "channel-a-ed-only",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
            ],
        );

        // An empty key matches nothing; a single-identity key agrees with `by_id`,
        // and `by_id` agrees with the generated per-key-type lookups.
        assert_eq!(list.by_all_ids(&ids(None, None)).count(), 0);
        assert!(list.all_overlapping(&ids(None, None)).is_empty());
        assert!(list.all_subset(&ids(None, None)).is_empty());
        assert_eq!(
            sort(
                list.all_overlapping(&ids(rsa_a, None))
                    .into_iter()
                    .map(|x| x.val)
            ),
            sort(list.by_id(&rsa_a).map(|x| x.val)),
        );
        assert_eq!(
            sort(
                list.all_overlapping(&ids(None, ed_b))
                    .into_iter()
                    .map(|x| x.val)
            ),
            sort(list.by_id(&ed_b).map(|x| x.val)),
        );
        assert_eq!(
            sort(list.by_id(&rsa_a).map(|x| x.val)),
            sort(list.by_rsa(&rsa_a).map(|x| x.val)),
        );
        assert_eq!(
            sort(list.by_id(&ed_a).map(|x| x.val)),
            sort(list.by_ed25519(&ed_a).map(|x| x.val)),
        );

        {
            // Exact removal with both identities removes only channel-a-all.
            let mut list = list.clone();
            assert_eq!(
                sort(
                    list.remove_exact(&ids(rsa_a, ed_a))
                        .into_iter()
                        .map(|x| x.val)
                ),
                ["channel-a-all"],
            );
            assert_eq!(list.by_all_ids(&ids(rsa_a, ed_a)).count(), 0);
        }

        {
            // Exact removal by rsa_a alone removes only the RSA-only values;
            // values that also have an Ed25519 identity remain.
            let mut list = list.clone();
            assert_eq!(
                sort(
                    list.remove_exact(&ids(rsa_a, None))
                        .into_iter()
                        .map(|x| x.val)
                ),
                ["channel-a-rsa-only-1", "channel-a-rsa-only-2"],
            );
            assert_eq!(
                sort(list.by_all_ids(&ids(rsa_a, None)).map(|x| x.val)),
                ["channel-a-all", "channel-invalid"],
            );
        }

        {
            // Removal by all ids with rsa_a removes every value that has rsa_a.
            let mut list = list.clone();
            assert_eq!(
                sort(
                    list.remove_by_all_ids(&ids(rsa_a, None))
                        .into_iter()
                        .map(|x| x.val)
                ),
                [
                    "channel-a-all",
                    "channel-a-rsa-only-1",
                    "channel-a-rsa-only-2",
                    "channel-invalid",
                ],
            );
            assert_eq!(list.by_all_ids(&ids(rsa_a, None)).count(), 0);
        }
    }
}