1use crate::ConsensusRelays;
16use crate::params::NetParameters;
17use bitflags::bitflags;
18use tor_netdoc::doc::netstatus::{self, MdConsensus, MdRouterStatus, NetParams};
19
20fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23where
24 I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25{
26 let has_measured = weights.clone().any(|w| w.is_measured());
27 let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28 let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
30 if !has_nonzero {
31 BandwidthFn::Uniform
34 } else if !has_measured {
35 BandwidthFn::IncludeUnmeasured
38 } else if has_nonzero_measured {
39 BandwidthFn::MeasuredOnly
42 } else {
43 BandwidthFn::Uniform
47 }
48}
49
/// Rule for turning a relay's consensus weight entry into a bandwidth number.
///
/// Chosen once per consensus by `pick_bandwidth_fn` and applied to each relay
/// by `BandwidthFn::apply`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum BandwidthFn {
    /// Every relay counts as having bandwidth 1.
    Uniform,
    /// Use the value of each entry, whether measured or unmeasured.
    IncludeUnmeasured,
    /// Use the value of measured entries; everything else counts as 0.
    MeasuredOnly,
}
64
65impl BandwidthFn {
66 fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69 use BandwidthFn::*;
70 use netstatus::RelayWeight::*;
71 match (self, w) {
72 (Uniform, _) => 1,
73 (IncludeUnmeasured, Unmeasured(u)) => *u,
74 (IncludeUnmeasured, Measured(m)) => *m,
75 (MeasuredOnly, Unmeasured(_)) => 0,
76 (MeasuredOnly, Measured(m)) => *m,
77 (_, _) => 0,
78 }
79 }
80}
81
/// The role a relay is being selected for.
///
/// The role decides which column of the per-relay weight table is used
/// (see `RelayWeight::for_role`).
#[derive(Clone, Debug, Copy)]
#[non_exhaustive]
pub enum WeightRole {
    /// Use the relay's guard weight.
    Guard,
    /// Use the relay's middle weight.
    Middle,
    /// Use the relay's exit weight.
    Exit,
    /// Use the relay's directory weight.
    BeginDir,
    /// Use a role weight of 1, so only the relay's bandwidth matters.
    Unweighted,
    /// Weighted the same way as `Middle`.
    HsIntro,
    /// Weighted the same way as `Middle`.
    HsRend,
}
106
/// One row of the weight table: the multiplier applied to a relay's
/// bandwidth for each role it might be chosen for.
#[derive(Clone, Debug, Copy)]
struct RelayWeight {
    /// Multiplier when choosing a guard.
    as_guard: u32,
    /// Multiplier when choosing a middle relay.
    as_middle: u32,
    /// Multiplier when choosing an exit.
    as_exit: u32,
    /// Multiplier when choosing a directory relay.
    as_dir: u32,
}
119
120impl std::ops::Mul<u32> for RelayWeight {
121 type Output = Self;
122 fn mul(self, rhs: u32) -> Self {
123 RelayWeight {
124 as_guard: self.as_guard * rhs,
125 as_middle: self.as_middle * rhs,
126 as_exit: self.as_exit * rhs,
127 as_dir: self.as_dir * rhs,
128 }
129 }
130}
131impl std::ops::Div<u32> for RelayWeight {
132 type Output = Self;
133 fn div(self, rhs: u32) -> Self {
134 RelayWeight {
135 as_guard: self.as_guard / rhs,
136 as_middle: self.as_middle / rhs,
137 as_exit: self.as_exit / rhs,
138 as_dir: self.as_dir / rhs,
139 }
140 }
141}
142
143impl RelayWeight {
144 #[allow(clippy::unwrap_used)]
147 fn max_weight(&self) -> u32 {
148 [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
149 .iter()
150 .max()
151 .copied()
152 .unwrap()
153 }
154 fn for_role(&self, role: WeightRole) -> u32 {
157 match role {
158 WeightRole::Guard => self.as_guard,
159 WeightRole::Middle => self.as_middle,
160 WeightRole::Exit => self.as_exit,
161 WeightRole::BeginDir => self.as_dir,
162 WeightRole::HsIntro => self.as_middle, WeightRole::HsRend => self.as_middle, WeightRole::Unweighted => 1,
165 }
166 }
167}
168
169bitflags! {
170 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
175 struct WeightKind: u8 {
176 const GUARD = 1 << 0;
178 const EXIT = 1 << 1;
180 const DIR = 1 << 2;
182 }
183}
184
185impl WeightKind {
186 fn for_rs(rs: &MdRouterStatus) -> Self {
188 let mut r = WeightKind::empty();
189 if rs.is_flagged_guard() {
190 r |= WeightKind::GUARD;
191 }
192 if rs.is_flagged_exit() {
193 r |= WeightKind::EXIT;
194 }
195 if rs.is_flagged_v2dir() {
196 r |= WeightKind::DIR;
197 }
198 r
199 }
200 fn idx(self) -> usize {
202 self.bits() as usize
203 }
204}
205
206#[derive(Debug, Clone)]
209pub(crate) struct WeightSet {
210 bandwidth_fn: BandwidthFn,
216 shift: u8,
227 w: [RelayWeight; 8],
230}
231
232impl WeightSet {
233 pub(crate) fn weight_rs_for_role(&self, rs: &MdRouterStatus, role: WeightRole) -> u64 {
240 self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
241 }
242
243 fn weight_bw_for_role(
246 &self,
247 kind: WeightKind,
248 relay_weight: &netstatus::RelayWeight,
249 role: WeightRole,
250 ) -> u64 {
251 let ws = &self.w[kind.idx()];
252
253 let router_bw = self.bandwidth_fn.apply(relay_weight);
254 let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
259 router_weight >> self.shift
260 }
261
262 pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
264 let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
265 let weight_scale = params.bw_weight_scale.into();
266
267 let total_bw = consensus
268 .c_relays()
269 .iter()
270 .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
271 .sum();
272 let p = consensus.bandwidth_weights();
273
274 Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
275 }
276
277 fn from_parts(
281 bandwidth_fn: BandwidthFn,
282 total_bw: u64,
283 weight_scale: u32,
284 p: &NetParams<i32>,
285 ) -> Self {
286 #[allow(clippy::many_single_char_names)]
292 fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
293 RelayWeight {
294 as_guard: w_param(p, g),
295 as_middle: w_param(p, m),
296 as_exit: w_param(p, e),
297 as_dir: w_param(p, d),
298 }
299 }
300
301 let weight_scale = weight_scale.max(1);
304
305 let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
311 let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
312 let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
313 let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
314
315 let w = [
318 w_none,
319 w_guard,
320 w_exit,
321 w_both,
322 (w_none * w_param(p, "Wmb")) / weight_scale,
328 (w_guard * w_param(p, "Wgb")) / weight_scale,
329 (w_exit * w_param(p, "Web")) / weight_scale,
330 (w_both * w_param(p, "Wdb")) / weight_scale,
331 ];
332
333 #[allow(clippy::unwrap_used)]
336 let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
337
338 let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
340
341 WeightSet {
342 bandwidth_fn,
343 shift,
344 w,
345 }
346 }
347
348 fn validate(self, consensus: &MdConsensus) -> Self {
351 use WeightRole::*;
352 for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
353 let _: u64 = consensus
354 .c_relays()
355 .iter()
356 .map(|rs| self.weight_rs_for_role(rs, role))
357 .fold(0_u64, |a, b| {
358 a.checked_add(b)
359 .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
360 });
361 }
362 self
363 }
364}
365
366const DFLT_WEIGHT: i32 = 1;
373
374fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
379 if kwd == "---" {
380 0
381 } else {
382 clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
383 }
384}
385
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    // The conversion fails exactly when `inp` is negative.
    u32::try_from(inp).unwrap_or(0)
}
392
/// Return how many bits a product of a value up to `a` and a value up to `b`
/// must be shifted right so that it fits in 64 bits.
///
/// The estimate is based on the bit lengths of `a` and `b`, and is 0 when no
/// shift is needed.
fn calculate_shift(a: u64, b: u64) -> u32 {
    let needed_bits = log2_upper(a) + log2_upper(b);
    needed_bits.saturating_sub(u64::BITS)
}

/// Return the number of bits needed to represent `n`: 0 for 0, otherwise
/// one more than the index of the highest set bit.
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
407
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, MdRouterStatusBuilder};
    use tor_netdoc::types::relay_flags::{RelayFlag, RelayFlags};
    use web_time_compat::SystemTimeExt;

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7);
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(
            ((432_u64 << 40) >> 38)
                .checked_mul(7777_u64 << 40)
                .is_some()
        );
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // Measured entries exist, but all of them are zero.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use BandwidthFn::*;
        use netstatus::RelayWeight::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    const TESTVEC_PARAMS: &str = "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_middle,
            4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        let rs = rs_builder()
            .set_flags(RelayFlag::Guard | RelayFlag::V2Dir)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a router-status builder with every required field filled in.
    fn rs_builder() -> MdRouterStatusBuilder {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlag::Exit).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlag::Guard).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlag::V2Dir).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::get();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}