1use crate::params::NetParameters;
16use crate::ConsensusRelays;
17use bitflags::bitflags;
18use tor_netdoc::doc::netstatus::{self, MdConsensus, MdConsensusRouterStatus, NetParams};
19
20fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23where
24 I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25{
26 let has_measured = weights.clone().any(|w| w.is_measured());
27 let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28 let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
30 if !has_nonzero {
31 BandwidthFn::Uniform
34 } else if !has_measured {
35 BandwidthFn::IncludeUnmeasured
38 } else if has_nonzero_measured {
39 BandwidthFn::MeasuredOnly
42 } else {
43 BandwidthFn::Uniform
47 }
48}
49
/// A rule for turning a relay's consensus bandwidth entry into a number.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BandwidthFn {
    /// Ignore the consensus value: every relay counts as bandwidth 1.
    Uniform,
    /// Use the bandwidth value whether or not it was measured.
    IncludeUnmeasured,
    /// Use measured bandwidth values; unmeasured relays count as 0.
    MeasuredOnly,
}
64
65impl BandwidthFn {
66 fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69 use netstatus::RelayWeight::*;
70 use BandwidthFn::*;
71 match (self, w) {
72 (Uniform, _) => 1,
73 (IncludeUnmeasured, Unmeasured(u)) => *u,
74 (IncludeUnmeasured, Measured(m)) => *m,
75 (MeasuredOnly, Unmeasured(_)) => 0,
76 (MeasuredOnly, Measured(m)) => *m,
77 (_, _) => 0,
78 }
79 }
80}
81
/// The purpose for which a relay is being chosen.
///
/// The role decides which of a relay's per-role weights is applied.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum WeightRole {
    /// Choosing a relay to use as a guard (uses the guard weight).
    Guard,
    /// Choosing a relay for the middle of a circuit (uses the middle weight).
    Middle,
    /// Choosing a relay for the exit position (uses the exit weight).
    Exit,
    /// Choosing a relay for a BEGIN_DIR directory request (uses the directory weight).
    BeginDir,
    /// Choosing a relay with no weighting beyond its bandwidth.
    Unweighted,
    /// Choosing a relay for an onion-service introduction role; weighted like `Middle`.
    HsIntro,
}
109
/// The weight multipliers that apply to one kind of relay, one per role.
#[derive(Clone, Copy, Debug)]
struct RelayWeight {
    /// Multiplier when choosing the relay as a guard.
    as_guard: u32,
    /// Multiplier when choosing the relay as a middle hop.
    as_middle: u32,
    /// Multiplier when choosing the relay as an exit.
    as_exit: u32,
    /// Multiplier when choosing the relay for a directory request.
    as_dir: u32,
}
122
123impl std::ops::Mul<u32> for RelayWeight {
124 type Output = Self;
125 fn mul(self, rhs: u32) -> Self {
126 RelayWeight {
127 as_guard: self.as_guard * rhs,
128 as_middle: self.as_middle * rhs,
129 as_exit: self.as_exit * rhs,
130 as_dir: self.as_dir * rhs,
131 }
132 }
133}
134impl std::ops::Div<u32> for RelayWeight {
135 type Output = Self;
136 fn div(self, rhs: u32) -> Self {
137 RelayWeight {
138 as_guard: self.as_guard / rhs,
139 as_middle: self.as_middle / rhs,
140 as_exit: self.as_exit / rhs,
141 as_dir: self.as_dir / rhs,
142 }
143 }
144}
145
146impl RelayWeight {
147 #[allow(clippy::unwrap_used)]
150 fn max_weight(&self) -> u32 {
151 [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
152 .iter()
153 .max()
154 .copied()
155 .unwrap()
156 }
157 fn for_role(&self, role: WeightRole) -> u32 {
160 match role {
161 WeightRole::Guard => self.as_guard,
162 WeightRole::Middle => self.as_middle,
163 WeightRole::Exit => self.as_exit,
164 WeightRole::BeginDir => self.as_dir,
165 WeightRole::HsIntro => self.as_middle, WeightRole::Unweighted => 1,
167 }
168 }
169}
170
171bitflags! {
172 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
177 struct WeightKind: u8 {
178 const GUARD = 1 << 0;
180 const EXIT = 1 << 1;
182 const DIR = 1 << 2;
184 }
185}
186
187impl WeightKind {
188 fn for_rs(rs: &MdConsensusRouterStatus) -> Self {
190 let mut r = WeightKind::empty();
191 if rs.is_flagged_guard() {
192 r |= WeightKind::GUARD;
193 }
194 if rs.is_flagged_exit() {
195 r |= WeightKind::EXIT;
196 }
197 if rs.is_flagged_v2dir() {
198 r |= WeightKind::DIR;
199 }
200 r
201 }
202 fn idx(self) -> usize {
204 self.bits() as usize
205 }
206}
207
208#[derive(Debug, Clone)]
211pub(crate) struct WeightSet {
212 bandwidth_fn: BandwidthFn,
218 shift: u8,
229 w: [RelayWeight; 8],
232}
233
234impl WeightSet {
235 pub(crate) fn weight_rs_for_role(&self, rs: &MdConsensusRouterStatus, role: WeightRole) -> u64 {
242 self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
243 }
244
245 fn weight_bw_for_role(
248 &self,
249 kind: WeightKind,
250 relay_weight: &netstatus::RelayWeight,
251 role: WeightRole,
252 ) -> u64 {
253 let ws = &self.w[kind.idx()];
254
255 let router_bw = self.bandwidth_fn.apply(relay_weight);
256 let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
261 router_weight >> self.shift
262 }
263
264 pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
266 let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
267 let weight_scale = params.bw_weight_scale.into();
268
269 let total_bw = consensus
270 .c_relays()
271 .iter()
272 .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
273 .sum();
274 let p = consensus.bandwidth_weights();
275
276 Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
277 }
278
279 fn from_parts(
283 bandwidth_fn: BandwidthFn,
284 total_bw: u64,
285 weight_scale: u32,
286 p: &NetParams<i32>,
287 ) -> Self {
288 #[allow(clippy::many_single_char_names)]
294 fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
295 RelayWeight {
296 as_guard: w_param(p, g),
297 as_middle: w_param(p, m),
298 as_exit: w_param(p, e),
299 as_dir: w_param(p, d),
300 }
301 }
302
303 let weight_scale = weight_scale.max(1);
306
307 let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
313 let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
314 let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
315 let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
316
317 let w = [
320 w_none,
321 w_guard,
322 w_exit,
323 w_both,
324 (w_none * w_param(p, "Wmb")) / weight_scale,
330 (w_guard * w_param(p, "Wgb")) / weight_scale,
331 (w_exit * w_param(p, "Web")) / weight_scale,
332 (w_both * w_param(p, "Wdb")) / weight_scale,
333 ];
334
335 #[allow(clippy::unwrap_used)]
338 let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
339
340 let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
342
343 WeightSet {
344 bandwidth_fn,
345 shift,
346 w,
347 }
348 }
349
350 fn validate(self, consensus: &MdConsensus) -> Self {
353 use WeightRole::*;
354 for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
355 let _: u64 = consensus
356 .c_relays()
357 .iter()
358 .map(|rs| self.weight_rs_for_role(rs, role))
359 .fold(0_u64, |a, b| {
360 a.checked_add(b)
361 .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
362 });
363 }
364 self
365 }
366}
367
/// Value `w_param` uses for a bandwidth-weight parameter that is absent from
/// the consensus parameters.
const DFLT_WEIGHT: i32 = 1;
375
376fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
381 if kwd == "---" {
382 0
383 } else {
384 clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
385 }
386}
387
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    // Every non-negative i32 fits in a u32; only negative inputs fail to convert.
    u32::try_from(inp).unwrap_or(0)
}
398
399fn calculate_shift(a: u64, b: u64) -> u32 {
402 let bits_for_product = log2_upper(a) + log2_upper(b);
403 if bits_for_product < 64 {
404 0
405 } else {
406 bits_for_product - 64
407 }
408}
409
/// Return the number of bits needed to represent `n` (0 for `n == 0`).
///
/// For nonzero `n` this is `floor(log2(n)) + 1`, so it always exceeds
/// `log2(n)`; for example, `log2_upper(64) == 7`.
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
417
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, RelayFlags, RouterStatusBuilder};

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7);
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(((432_u64 << 40) >> 38)
            .checked_mul(7777_u64 << 40)
            .is_some());
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // Measured values exist but are all zero: fall back to Uniform.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use netstatus::RelayWeight::*;
        use BandwidthFn::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    /// Bandwidth-weight parameters used by the tests below; scale 10000.
    const TESTVEC_PARAMS: &str =
        "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        // DIR-only relays: the no-flag row scaled by Wmb/10000 (= 1), so the
        // directory weight is Wbm.
        assert_eq!(ws.w[(WeightKind::DIR.bits()) as usize].as_dir, 10000);
        // EXIT|DIR relays inherit the exit-only row, which has no guard weight.
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::DIR).bits() as usize].as_guard,
            0
        );
        // All flags: the Guard+Exit row scaled by Wdb/10000 (= 1), so exit weight is Wed.
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::EXIT | WeightKind::DIR).bits() as usize].as_exit,
            10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        let rs = rs_builder()
            .set_flags(RelayFlags::GUARD | RelayFlags::V2DIR)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a router-status builder with fixed identity, OR port, digest
    /// and protocols, for tests to customize.
    fn rs_builder() -> RouterStatusBuilder<[u8; 32]> {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlags::EXIT).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlags::GUARD).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlags::V2DIR).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::now();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // Ten relays with unmeasured bandwidth...
        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }
        // ...and thirty with measured bandwidth (mostly nonzero).
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}