trippy_core/
state.rs

1use crate::config::StateConfig;
2use crate::constants::MAX_TTL;
3use crate::flows::{Flow, FlowId, FlowRegistry};
4use crate::{
5    Dscp, Ecn, Extensions, IcmpPacketType, ProbeStatus, Round, RoundId, TimeToLive, TypeOfService,
6};
7use indexmap::IndexMap;
8use std::collections::HashMap;
9use std::iter::once;
10use std::net::IpAddr;
11use std::time::Duration;
12use tracing::instrument;
13
14/// The state of a trace.
15#[derive(Debug, Clone, Default)]
16pub struct State {
17    /// The configuration for the state.
18    state_config: StateConfig,
19    /// The flow id for the current round.
20    round_flow_id: FlowId,
21    /// Tracing state per registered flow id.
22    state: HashMap<FlowId, FlowState>,
23    /// Flow registry.
24    registry: FlowRegistry,
25    /// Tracing error message.
26    error: Option<String>,
27}
28
29impl State {
30    /// Create a new `State`.
31    #[must_use]
32    pub fn new(state_config: StateConfig) -> Self {
33        Self {
34            state: once((
35                Self::default_flow_id(),
36                FlowState::new(state_config.max_samples),
37            ))
38            .collect::<HashMap<FlowId, FlowState>>(),
39            round_flow_id: Self::default_flow_id(),
40            state_config,
41            registry: FlowRegistry::new(),
42            error: None,
43        }
44    }
45
46    /// Return the id of the default flow.
47    #[must_use]
48    pub const fn default_flow_id() -> FlowId {
49        FlowId(0)
50    }
51
52    /// Information about each hop for the combined default flow.
53    #[must_use]
54    pub fn hops(&self) -> &[Hop] {
55        self.state[&Self::default_flow_id()].hops()
56    }
57
58    /// Information about each hop for a given flow.
59    #[must_use]
60    pub fn hops_for_flow(&self, flow_id: FlowId) -> &[Hop] {
61        self.state[&flow_id].hops()
62    }
63
64    /// Is a given `Hop` the target hop for a given flow?
65    ///
66    /// A `Hop` is considered to be the target if it has the highest `ttl` value observed.
67    ///
68    /// Note that if the target host does not respond to probes then the highest `ttl` observed
69    /// will be one greater than the `ttl` of the last host which did respond.
70    #[must_use]
71    pub fn is_target(&self, hop: &Hop, flow_id: FlowId) -> bool {
72        self.state[&flow_id].is_target(hop)
73    }
74
75    /// Is a given `Hop` in the current round for a given flow?
76    #[must_use]
77    pub fn is_in_round(&self, hop: &Hop, flow_id: FlowId) -> bool {
78        self.state[&flow_id].is_in_round(hop)
79    }
80
81    /// Return the target `Hop` for a given flow.
82    #[must_use]
83    pub fn target_hop(&self, flow_id: FlowId) -> &Hop {
84        self.state[&flow_id].target_hop()
85    }
86
87    /// The current round of tracing for a given flow.
88    #[must_use]
89    pub fn round(&self, flow_id: FlowId) -> Option<usize> {
90        self.state[&flow_id].round()
91    }
92
93    /// The total rounds of tracing for a given flow.
94    #[must_use]
95    pub fn round_count(&self, flow_id: FlowId) -> usize {
96        self.state[&flow_id].round_count()
97    }
98
99    /// The `FlowId` for the current round.
100    #[must_use]
101    pub const fn round_flow_id(&self) -> FlowId {
102        self.round_flow_id
103    }
104
105    /// The registry of flows in the trace.
106    #[must_use]
107    pub fn flows(&self) -> &[(Flow, FlowId)] {
108        self.registry.flows()
109    }
110
111    /// The error message for the trace, if any.
112    #[must_use]
113    pub fn error(&self) -> Option<&str> {
114        self.error.as_deref()
115    }
116
117    pub fn set_error(&mut self, error: Option<String>) {
118        self.error = error;
119    }
120
121    /// The maximum number of samples to record per hop.
122    #[must_use]
123    pub const fn max_samples(&self) -> usize {
124        self.state_config.max_samples
125    }
126
127    /// The maximum number of flows to record.
128    #[must_use]
129    pub const fn max_flows(&self) -> usize {
130        self.state_config.max_flows
131    }
132
133    /// Update the tracing state from a `TracerRound`.
134    #[instrument(skip(self, round), level = "trace")]
135    pub fn update_from_round(&mut self, round: &Round<'_>) {
136        let flow = Flow::from_hops(
137            round
138                .probes
139                .iter()
140                .filter_map(|probe| match probe {
141                    ProbeStatus::Awaited(_) => Some(None),
142                    ProbeStatus::Complete(completed) => Some(Some(completed.host)),
143                    _ => None,
144                })
145                .take(usize::from(round.largest_ttl.0)),
146        );
147        self.update_trace_flow(Self::default_flow_id(), round);
148        if self.registry.flows().len() < self.state_config.max_flows {
149            let flow_id = self.registry.register(flow);
150            self.round_flow_id = flow_id;
151            self.update_trace_flow(flow_id, round);
152        }
153    }
154
155    #[instrument(skip(self, round), level = "trace")]
156    fn update_trace_flow(&mut self, flow_id: FlowId, round: &Round<'_>) {
157        let flow_trace = self
158            .state
159            .entry(flow_id)
160            .or_insert_with(|| FlowState::new(self.state_config.max_samples));
161        flow_trace.update_from_round(round);
162    }
163}
164
/// Information about a single `Hop` within a trace.
///
/// Accumulated across rounds: counters and timing statistics are updated each time a probe for
/// this hop's ttl is applied.
#[derive(Debug, Clone)]
pub struct Hop {
    /// The ttl of this hop.
    ttl: u8,
    /// The addrs of this hop and associated counts (number of responses per address).
    addrs: IndexMap<IpAddr, usize>,
    /// The total probes sent for this hop.
    total_sent: usize,
    /// The total probes received for this hop.
    total_recv: usize,
    /// The total probes that failed for this hop.
    total_failed: usize,
    /// The total forward loss for this hop.
    total_forward_lost: usize,
    /// The total backward loss for this hop.
    total_backward_lost: usize,
    /// The total round trip time for this hop across all rounds.
    total_time: Duration,
    /// The round trip time for this hop in the current round.
    last: Option<Duration>,
    /// The best round trip time for this hop across all rounds.
    best: Option<Duration>,
    /// The worst round trip time for this hop across all rounds.
    worst: Option<Duration>,
    /// The current jitter i.e. round-trip difference with the last round-trip.
    jitter: Option<Duration>,
    /// The average jitter time for all probes at this hop, in milliseconds.
    javg: f64,
    /// The worst round-trip jitter time for all probes at this hop.
    jmax: Option<Duration>,
    /// The smoothed jitter value for all probes at this hop, accumulated in milliseconds.
    jinta: f64,
    /// The source port for last probe for this hop.
    last_src_port: u16,
    /// The destination port for last probe for this hop.
    last_dest_port: u16,
    /// The sequence number for the last probe for this hop.
    last_sequence: u16,
    /// The icmp packet type for the last probe for this hop.
    last_icmp_packet_type: Option<IcmpPacketType>,
    /// The NAT detection status for the last probe for this hop.
    last_nat_status: NatStatus,
    /// The history of round trip times across the last N rounds.
    samples: Vec<Duration>,
    /// The type of service (DSCP/ECN) for this hop.
    tos: Option<TypeOfService>,
    /// The ICMP extensions for this hop.
    extensions: Option<Extensions>,
    /// Running mean of the round-trip times (milliseconds) of the received probes.
    mean: f64,
    /// Accumulator of squared deviations from `mean` (milliseconds squared).
    ///
    /// `stddev_ms` divides this by `total_recv - 1` and takes the square root.
    m2: f64,
}
217
218impl Hop {
219    /// The time-to-live of this hop.
220    #[must_use]
221    pub const fn ttl(&self) -> u8 {
222        self.ttl
223    }
224
225    /// The set of addresses that have responded for this time-to-live.
226    pub fn addrs(&self) -> impl Iterator<Item = &IpAddr> {
227        self.addrs.keys()
228    }
229
230    pub fn addrs_with_counts(&self) -> impl Iterator<Item = (&IpAddr, &usize)> {
231        self.addrs.iter()
232    }
233
234    /// The number of unique address observed for this time-to-live.
235    #[must_use]
236    pub fn addr_count(&self) -> usize {
237        self.addrs.len()
238    }
239
240    /// The total number of probes sent.
241    #[must_use]
242    pub const fn total_sent(&self) -> usize {
243        self.total_sent
244    }
245
246    /// The total number of probes responses received.
247    #[must_use]
248    pub const fn total_recv(&self) -> usize {
249        self.total_recv
250    }
251
252    /// The total number of probes with forward loss.
253    #[must_use]
254    pub const fn total_forward_loss(&self) -> usize {
255        self.total_forward_lost
256    }
257
258    /// The total number of probes with backward loss.
259    #[must_use]
260    pub const fn total_backward_loss(&self) -> usize {
261        self.total_backward_lost
262    }
263
264    /// The total number of probes that failed.
265    #[must_use]
266    pub const fn total_failed(&self) -> usize {
267        self.total_failed
268    }
269
270    /// The % of packets that are lost.
271    #[must_use]
272    pub fn loss_pct(&self) -> f64 {
273        if self.total_sent > 0 {
274            let lost = self.total_sent - self.total_recv;
275            lost as f64 / self.total_sent as f64 * 100_f64
276        } else {
277            0_f64
278        }
279    }
280
281    /// The % of packets that are lost forward.
282    #[must_use]
283    pub fn forward_loss_pct(&self) -> f64 {
284        if self.total_sent > 0 {
285            let lost = self.total_forward_lost;
286            lost as f64 / self.total_sent as f64 * 100_f64
287        } else {
288            0_f64
289        }
290    }
291
292    /// The % of packets that are lost backward.
293    #[must_use]
294    pub fn backward_loss_pct(&self) -> f64 {
295        if self.total_sent > 0 {
296            let lost = self.total_backward_lost;
297            lost as f64 / self.total_sent as f64 * 100_f64
298        } else {
299            0_f64
300        }
301    }
302
303    /// The duration of the last probe.
304    #[must_use]
305    pub fn last_ms(&self) -> Option<f64> {
306        self.last.map(|last| last.as_secs_f64() * 1000_f64)
307    }
308
309    /// The duration of the best probe observed.
310    #[must_use]
311    pub fn best_ms(&self) -> Option<f64> {
312        self.best.map(|last| last.as_secs_f64() * 1000_f64)
313    }
314
315    /// The duration of the worst probe observed.
316    #[must_use]
317    pub fn worst_ms(&self) -> Option<f64> {
318        self.worst.map(|last| last.as_secs_f64() * 1000_f64)
319    }
320
321    /// The average duration of all probes.
322    #[must_use]
323    pub fn avg_ms(&self) -> f64 {
324        if self.total_recv() > 0 {
325            (self.total_time.as_secs_f64() * 1000_f64) / self.total_recv as f64
326        } else {
327            0_f64
328        }
329    }
330
331    /// The standard deviation of all probes.
332    #[must_use]
333    pub fn stddev_ms(&self) -> f64 {
334        if self.total_recv > 1 {
335            (self.m2 / (self.total_recv - 1) as f64).sqrt()
336        } else {
337            0_f64
338        }
339    }
340
341    /// The duration of the jitter probe observed.
342    #[must_use]
343    pub fn jitter_ms(&self) -> Option<f64> {
344        self.jitter.map(|j| j.as_secs_f64() * 1000_f64)
345    }
346
347    /// The duration of the worst probe observed.
348    #[must_use]
349    pub fn jmax_ms(&self) -> Option<f64> {
350        self.jmax.map(|x| x.as_secs_f64() * 1000_f64)
351    }
352
353    /// The jitter average duration of all probes.
354    #[must_use]
355    pub const fn javg_ms(&self) -> f64 {
356        self.javg
357    }
358
359    /// The jitter interval of all probes.
360    #[must_use]
361    pub const fn jinta(&self) -> f64 {
362        self.jinta
363    }
364
365    /// The source port for last probe for this hop.
366    #[must_use]
367    pub const fn last_src_port(&self) -> u16 {
368        self.last_src_port
369    }
370
371    /// The destination port for last probe for this hop.
372    #[must_use]
373    pub const fn last_dest_port(&self) -> u16 {
374        self.last_dest_port
375    }
376
377    /// The sequence number for the last probe for this hop.
378    #[must_use]
379    pub const fn last_sequence(&self) -> u16 {
380        self.last_sequence
381    }
382
383    /// The icmp packet type for the last probe for this hop.
384    #[must_use]
385    pub const fn last_icmp_packet_type(&self) -> Option<IcmpPacketType> {
386        self.last_icmp_packet_type
387    }
388
389    /// The NAT detection status for the last probe for this hop.
390    #[must_use]
391    pub const fn last_nat_status(&self) -> NatStatus {
392        self.last_nat_status
393    }
394
395    /// The type of service (DSCP/ECN) for this hop.
396    #[must_use]
397    pub fn tos(&self) -> Option<TypeOfService> {
398        self.tos
399    }
400
401    /// The `DSCP` for this hop.
402    #[must_use]
403    pub fn dscp(&self) -> Option<Dscp> {
404        self.tos.map(|tos| tos.dscp())
405    }
406
407    /// The `ECN` for this hop.
408    #[must_use]
409    pub fn ecn(&self) -> Option<Ecn> {
410        self.tos.map(|tos| tos.ecn())
411    }
412
413    /// The last N samples.
414    #[must_use]
415    pub fn samples(&self) -> &[Duration] {
416        &self.samples
417    }
418
419    #[must_use]
420    pub const fn extensions(&self) -> Option<&Extensions> {
421        self.extensions.as_ref()
422    }
423}
424
425impl Default for Hop {
426    fn default() -> Self {
427        Self {
428            ttl: 0,
429            addrs: IndexMap::default(),
430            total_sent: 0,
431            total_recv: 0,
432            total_forward_lost: 0,
433            total_backward_lost: 0,
434            total_failed: 0,
435            total_time: Duration::default(),
436            last: None,
437            best: None,
438            worst: None,
439            jitter: None,
440            javg: 0f64,
441            jmax: None,
442            jinta: 0f64,
443            last_src_port: 0_u16,
444            last_dest_port: 0_u16,
445            last_sequence: 0_u16,
446            last_icmp_packet_type: None,
447            mean: 0f64,
448            m2: 0f64,
449            samples: Vec::default(),
450            tos: None,
451            extensions: None,
452            last_nat_status: NatStatus::NotApplicable,
453        }
454    }
455}
456
/// The state of a NAT detection for a `Hop`.
///
/// Determined by comparing the expected and actual UDP checksums carried by probe responses.
/// A hop for which no response carried both checksums remains `NotApplicable`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum NatStatus {
    /// NAT detection was not applicable (no checksum information was available).
    NotApplicable,
    /// NAT was not detected at this hop.
    NotDetected,
    /// NAT was detected at this hop.
    Detected,
}
467
/// Data for a single trace flow.
#[derive(Debug, Clone)]
struct FlowState {
    /// The maximum number of samples to record per hop.
    max_samples: usize,
    /// The lowest ttl observed across all rounds (`0` until a probe has been seen).
    lowest_ttl: u8,
    /// The highest ttl observed across all rounds (`0` until a round has been applied).
    highest_ttl: u8,
    /// The highest ttl for the latest round (taken from the round's `largest_ttl`).
    highest_ttl_for_round: u8,
    /// The latest (largest) round id received.
    round: Option<usize>,
    /// The total number of rounds received.
    round_count: usize,
    /// One `Hop` per possible ttl (`MAX_TTL` entries); the hop for ttl `n` is at index `n - 1`.
    hops: Vec<Hop>,
}
486
487impl FlowState {
488    fn new(max_samples: usize) -> Self {
489        Self {
490            max_samples,
491            lowest_ttl: 0,
492            highest_ttl: 0,
493            highest_ttl_for_round: 0,
494            round: None,
495            round_count: 0,
496            hops: (0..MAX_TTL).map(|_| Hop::default()).collect(),
497        }
498    }
499
500    fn hops(&self) -> &[Hop] {
501        if self.lowest_ttl == 0 || self.highest_ttl == 0 {
502            &[]
503        } else {
504            let start = (self.lowest_ttl as usize) - 1;
505            let end = self.highest_ttl as usize;
506            &self.hops[start..end]
507        }
508    }
509
510    const fn is_target(&self, hop: &Hop) -> bool {
511        self.highest_ttl_for_round == hop.ttl
512    }
513
514    const fn is_in_round(&self, hop: &Hop) -> bool {
515        hop.ttl <= self.highest_ttl_for_round
516    }
517
518    fn target_hop(&self) -> &Hop {
519        if self.highest_ttl_for_round > 0 {
520            &self.hops[usize::from(self.highest_ttl_for_round) - 1]
521        } else {
522            &self.hops[0]
523        }
524    }
525
526    const fn round(&self) -> Option<usize> {
527        self.round
528    }
529
530    const fn round_count(&self) -> usize {
531        self.round_count
532    }
533
534    fn update_from_round(&mut self, round: &Round<'_>) {
535        state_updater::StateUpdater::new(self, round).apply();
536    }
537
538    fn update_round(&mut self, round: RoundId) {
539        self.round = match self.round {
540            None => Some(round.0),
541            Some(r) => Some(r.max(round.0)),
542        }
543    }
544
545    fn update_lowest_ttl(&mut self, ttl: TimeToLive) {
546        if self.lowest_ttl == 0 {
547            self.lowest_ttl = ttl.0;
548        } else {
549            self.lowest_ttl = self.lowest_ttl.min(ttl.0);
550        }
551    }
552}
553
554mod state_updater {
555    use crate::state::FlowState;
556    use crate::types::Checksum;
557    use crate::{NatStatus, ProbeStatus, Round, TimeToLive};
558    use std::time::Duration;
559    use tracing::instrument;
560
    /// Update the state of a `FlowState` from a `Round`.
    ///
    /// Carries per-round bookkeeping (NAT checksum chaining and forward-loss tracking) while the
    /// probes of a single round are applied in order.
    pub(super) struct StateUpdater<'a> {
        /// The state to update.
        state: &'a mut FlowState,
        /// The `Round` being processed.
        round: &'a Round<'a>,
        /// The checksum carried forward from the previous hop that reported UDP checksums, if
        /// any; used for NAT detection on the next such hop.
        prev_hop_checksum: Option<u16>,
        /// Set once an awaited hop in this round is classified as forward loss; subsequent
        /// awaited hops in the round are then counted as backward loss.
        forward_loss: bool,
    }
572    impl<'a> StateUpdater<'a> {
573        pub(super) fn new(state: &'a mut FlowState, round: &'a Round<'_>) -> Self {
574            Self {
575                state,
576                round,
577                prev_hop_checksum: None,
578                forward_loss: false,
579            }
580        }
581
582        #[instrument(skip(self), level = "trace")]
583        pub(super) fn apply(&mut self) {
584            self.state.round_count += 1;
585            self.state.highest_ttl =
586                std::cmp::max(self.state.highest_ttl, self.round.largest_ttl.0);
587            self.state.highest_ttl_for_round = self.round.largest_ttl.0;
588            for probe in self.round.probes {
589                self.update_for_probe(probe);
590            }
591        }
592
593        #[instrument(skip(self), level = "trace")]
594        fn update_for_probe(&mut self, probe: &ProbeStatus) {
595            let state = &mut *self.state;
596            match probe {
597                ProbeStatus::Complete(complete) => {
598                    state.update_lowest_ttl(complete.ttl);
599                    state.update_round(complete.round);
600                    let index = usize::from(complete.ttl.0) - 1;
601                    let hop = &mut state.hops[index];
602                    hop.ttl = complete.ttl.0;
603                    hop.total_sent += 1;
604                    hop.total_recv += 1;
605                    let dur = complete
606                        .received
607                        .duration_since(complete.sent)
608                        .unwrap_or_default();
609                    let dur_ms = dur.as_secs_f64() * 1000_f64;
610                    hop.total_time += dur;
611                    // Before last is set use it to calc jitter
612                    let last_ms = hop.last_ms().unwrap_or_default();
613                    let jitter_ms = (dur_ms - last_ms).abs();
614                    let jitter_dur = Duration::from_secs_f64(jitter_ms / 1000_f64);
615                    hop.jitter = hop.last.and(Some(jitter_dur));
616                    hop.javg += (jitter_ms - hop.javg) / hop.total_recv as f64;
617                    // algorithm is from rfc1889, A.8 or rfc3550
618                    hop.jinta += jitter_ms.max(0.5) - ((hop.jinta + 8.0) / 16.0);
619                    hop.jmax = hop
620                        .jmax
621                        .map_or(Some(jitter_dur), |d| Some(d.max(jitter_dur)));
622                    hop.last = Some(dur);
623                    hop.samples.insert(0, dur);
624                    hop.best = hop.best.map_or(Some(dur), |d| Some(d.min(dur)));
625                    hop.worst = hop.worst.map_or(Some(dur), |d| Some(d.max(dur)));
626                    hop.mean += (dur_ms - hop.mean) / hop.total_recv as f64;
627                    hop.m2 += (dur_ms - hop.mean) * (dur_ms - hop.mean);
628                    if hop.samples.len() > state.max_samples {
629                        hop.samples.pop();
630                    }
631                    let host = complete.host;
632                    *hop.addrs.entry(host).or_default() += 1;
633                    hop.extensions.clone_from(&complete.extensions);
634                    hop.last_src_port = complete.src_port.0;
635                    hop.last_dest_port = complete.dest_port.0;
636                    hop.last_sequence = complete.sequence.0;
637                    hop.last_icmp_packet_type = Some(complete.icmp_packet_type);
638                    hop.tos = complete.tos;
639                    if let (Some(expected), Some(actual)) =
640                        (complete.expected_udp_checksum, complete.actual_udp_checksum)
641                    {
642                        let (nat_status, checksum) =
643                            nat_status(expected, actual, self.prev_hop_checksum);
644                        hop.last_nat_status = nat_status;
645                        self.prev_hop_checksum = Some(checksum);
646                    }
647                }
648                ProbeStatus::Awaited(awaited) => {
649                    state.update_lowest_ttl(awaited.ttl);
650                    state.update_round(awaited.round);
651                    let index = usize::from(awaited.ttl.0) - 1;
652                    let hop = &mut state.hops[index];
653                    hop.total_sent += 1;
654                    hop.ttl = awaited.ttl.0;
655                    hop.samples.insert(0, Duration::default());
656                    if hop.samples.len() > state.max_samples {
657                        hop.samples.pop();
658                    }
659                    hop.last_src_port = awaited.src_port.0;
660                    hop.last_dest_port = awaited.dest_port.0;
661                    hop.last_sequence = awaited.sequence.0;
662                    if self.forward_loss {
663                        hop.total_backward_lost += 1;
664                    } else if is_forward_loss(self.round.probes, awaited.ttl) {
665                        hop.total_forward_lost += 1;
666                        self.forward_loss = true;
667                    }
668                }
669                ProbeStatus::Failed(failed) => {
670                    state.update_lowest_ttl(failed.ttl);
671                    state.update_round(failed.round);
672                    let index = usize::from(failed.ttl.0) - 1;
673                    let hop = &mut state.hops[index];
674                    hop.total_sent += 1;
675                    hop.total_failed += 1;
676                    hop.ttl = failed.ttl.0;
677                    hop.samples.insert(0, Duration::default());
678                    if hop.samples.len() > state.max_samples {
679                        hop.samples.pop();
680                    }
681                    hop.last_src_port = failed.src_port.0;
682                    hop.last_dest_port = failed.dest_port.0;
683                    hop.last_sequence = failed.sequence.0;
684                }
685                ProbeStatus::NotSent | ProbeStatus::Skipped => {}
686            }
687        }
688    }
689
690    /// Determine if forward loss has occurred at a given time-to-live.
691    ///
692    /// This is determined by checking if all probes after the awaited probe are all also awaited.
693    fn is_forward_loss(probes: &[ProbeStatus], awaited_ttl: TimeToLive) -> bool {
694        // Skip all probes that have a ttl less than or equal to the awaited ttl. What remains
695        // are the probes we are interested in.
696        let mut remaining = probes
697            .iter()
698            .skip_while(|p| match p {
699                ProbeStatus::Awaited(a) => a.ttl <= awaited_ttl,
700                ProbeStatus::Complete(c) => c.ttl <= awaited_ttl,
701                ProbeStatus::Failed(f) => f.ttl <= awaited_ttl,
702                ProbeStatus::NotSent | ProbeStatus::Skipped => true,
703            })
704            .peekable();
705        let is_empty = remaining.peek().is_none();
706        let all_awaited =
707            remaining.all(|p| matches!(p, ProbeStatus::Awaited(_) | ProbeStatus::Skipped));
708        // If there is at least one probe remaining and all are awaited then we have forward loss.
709        !is_empty && all_awaited
710    }
711
712    /// Determine the NAT detection status.
713    ///
714    /// Returns a tuple of the NAT detection status and the checksum to use for the next hop.
715    const fn nat_status(
716        expected: Checksum,
717        actual: Checksum,
718        prev_hop_checksum: Option<u16>,
719    ) -> (NatStatus, u16) {
720        if let Some(prev_hop_checksum) = prev_hop_checksum {
721            // If the actual checksum matches the checksum of the previous probe
722            // then we can assume NAT has not occurred.  Note that it is perfectly
723            // valid for the expected checksum to differ from the actual checksum
724            // in this case as the NAT'ed checksum "carries forward" throughout the
725            // remainder of the hops on the path.
726            if prev_hop_checksum == actual.0 {
727                (NatStatus::NotDetected, prev_hop_checksum)
728            } else {
729                (NatStatus::Detected, actual.0)
730            }
731        } else {
732            // If we have no prior checksum (i.e. this is the first probe that
733            // responded) and the expected and actual checksums do not match then
734            // we can assume NAT has occurred.
735            if expected.0 == actual.0 {
736                (NatStatus::NotDetected, actual.0)
737            } else {
738                (NatStatus::Detected, actual.0)
739            }
740        }
741    }
742
743    #[cfg(test)]
744    mod tests {
745        use super::*;
746        use crate::probe::ProbeFailed;
747        use crate::{
748            Flags, IcmpPacketType, Port, Probe, ProbeComplete, RoundId, Sequence, TimeToLive,
749            TraceId,
750        };
751        use std::net::{IpAddr, Ipv4Addr};
752        use std::time::SystemTime;
753        use test_case::test_case;
754
755        #[test_case(false, &[], 1; "no forward loss no probes ttl 1")]
756        #[test_case(true, &[('a', 1), ('a', 2)], 1; "forward loss AA ttl 1")]
757        #[test_case(false, &[('a', 1), ('c', 2)], 1; "no forward loss AC ttl 1")]
758        #[test_case(false, &[('a', 1), ('f', 2)], 1; "no forward loss AF ttl 1")]
759        #[test_case(false, &[('a', 1), ('n', 2)], 1; "no forward loss AN ttl 1")]
760        #[test_case(false, &[('a', 1), ('c', 2), ('a', 3), ('a', 4)], 1; "no forward loss ACAA ttl 1")]
761        #[test_case(true, &[('a', 1), ('c', 2), ('a', 3), ('a', 4)], 3; "forward loss ACAA ttl 3")]
762        #[test_case(false, &[('a', 1), ('c', 2), ('a', 3), ('a', 4)], 4; "no forward loss ACAA ttl 4")]
763        #[test_case(false, &[('a', 1), ('f', 2), ('n', 3), ('a', 4)], 4; "no forward loss AFAN ttl 1")]
764        #[test_case(true, &[('a', 4), ('a', 5)], 4; "forward loss AA non-default minimum ttl 4")]
765        #[test_case(false, &[('a', 4), ('c', 5)], 4; "no forward loss AC non-default minimum ttl 4")]
766        #[test_case(false, &[('a', 4), ('c', 5), ('a', 6), ('a', 7)], 4; "no forward loss ACAA non-default minimum ttl 4")]
767        #[test_case(true, &[('a', 4), ('c', 5), ('a', 6), ('a', 7)], 6; "forward loss ACAA non-default minimum ttl 6")]
768        fn test_is_forward_loss(expected: bool, probes: &[(char, u8)], awaited_ttl: u8) {
769            assert!(awaited_ttl > 0);
770            let probes = probes
771                .iter()
772                .map(|(typ, ttl)| {
773                    assert!(matches!(typ, 'n' | 's' | 'f' | 'a' | 'c'));
774                    if *ttl == awaited_ttl {
775                        assert!(matches!(typ, 'a'));
776                    }
777                    match typ {
778                        'n' => ProbeStatus::NotSent,
779                        's' => ProbeStatus::Skipped,
780                        'f' => ProbeStatus::Failed(ProbeFailed {
781                            sequence: Sequence::default(),
782                            identifier: TraceId::default(),
783                            src_port: Port::default(),
784                            dest_port: Port::default(),
785                            ttl: TimeToLive(*ttl),
786                            round: RoundId::default(),
787                            sent: SystemTime::now(),
788                        }),
789                        'a' => ProbeStatus::Awaited(Probe {
790                            sequence: Sequence::default(),
791                            identifier: TraceId::default(),
792                            src_port: Port::default(),
793                            dest_port: Port::default(),
794                            ttl: TimeToLive(*ttl),
795                            round: RoundId(0),
796                            sent: SystemTime::now(),
797                            flags: Flags::empty(),
798                        }),
799                        'c' => ProbeStatus::Complete(ProbeComplete {
800                            sequence: Sequence::default(),
801                            identifier: TraceId::default(),
802                            src_port: Port::default(),
803                            dest_port: Port::default(),
804                            ttl: TimeToLive(*ttl),
805                            round: RoundId::default(),
806                            sent: SystemTime::now(),
807                            host: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
808                            received: SystemTime::now(),
809                            icmp_packet_type: IcmpPacketType::NotApplicable,
810                            tos: None,
811                            expected_udp_checksum: None,
812                            actual_udp_checksum: None,
813                            extensions: None,
814                        }),
815                        _ => unreachable!(),
816                    }
817                })
818                .collect::<Vec<_>>();
819            assert_eq!(is_forward_loss(&probes, TimeToLive(awaited_ttl)), expected);
820        }
821
822        #[test_case(123, 123, None => (NatStatus::NotDetected, 123); "first hop matching checksum")]
823        #[test_case(123, 321, None => (NatStatus::Detected, 321); "first hop non-matching checksum")]
824        #[test_case(123, 123, Some(123) => (NatStatus::NotDetected, 123); "non-first hop matching checksum match previous")]
825        #[test_case(999, 999, Some(321) => (NatStatus::Detected, 999); "non-first hop matching checksum not match previous")]
826        #[test_case(777, 888, Some(321) => (NatStatus::Detected, 888); "non-first hop non-matching checksum not match previous")]
827        const fn test_nat(expected: u16, actual: u16, prev: Option<u16>) -> (NatStatus, u16) {
828            nat_status(Checksum(expected), Checksum(actual), prev)
829        }
830    }
831}
832
833#[cfg(test)]
834mod tests {
835    use super::*;
836    use crate::types::Checksum;
837    use crate::{
838        CompletionReason, Flags, IcmpPacketType, Port, Probe, ProbeComplete, ProbeStatus, Sequence,
839        TimeToLive, TraceId, TypeOfService,
840    };
841    use anyhow::anyhow;
842    use serde::Deserialize;
843    use std::collections::HashSet;
844    use std::fmt::Debug;
845    use std::ops::Add;
846    use std::str::FromStr;
847    use std::time::SystemTime;
848    use test_case::test_case;
849
850    /// A test scenario.
851    #[derive(Deserialize, Debug)]
852    #[serde(deny_unknown_fields)]
853    struct Scenario {
854        /// the biggest ttl expected in this scenario
855        largest_ttl: u8,
856        /// The rounds of probe tracing data in this scenario.
857        rounds: Vec<RoundData>,
858        /// The expected outcome from running this scenario.
859        expected: Expected,
860    }
861
862    /// A single round of tracing probe data.
863    #[derive(Deserialize, Debug)]
864    #[serde(deny_unknown_fields)]
865    struct RoundData {
866        /// The probes in this round.
867        probes: Vec<ProbeData>,
868    }
869
870    /// A single probe from a single round.
871    #[derive(Deserialize, Debug)]
872    #[serde(deny_unknown_fields)]
873    #[serde(try_from = "String")]
874    struct ProbeData(ProbeStatus);
875
876    impl TryFrom<String> for ProbeData {
877        type Error = anyhow::Error;
878
879        fn try_from(value: String) -> Result<Self, Self::Error> {
880            // format: `{ttl} {status} {duration} {host} {sequence} {src_port} {dest_port} {checksum} {tos}`
881            let values = value.split_ascii_whitespace().collect::<Vec<_>>();
882            if values.len() == 10 {
883                let ttl = TimeToLive(u8::from_str(values[0])?);
884                let state = values[1].to_ascii_lowercase();
885                let sequence = Sequence(u16::from_str(values[4])?);
886                let src_port = Port(u16::from_str(values[5])?);
887                let dest_port = Port(u16::from_str(values[6])?);
888                let round = RoundId(0); // note we inject this later, see `ProbeRound`
889                let sent = SystemTime::now();
890                let flags = Flags::empty();
891                let state = match state.as_str() {
892                    "n" => Ok(ProbeStatus::NotSent),
893                    "s" => Ok(ProbeStatus::Skipped),
894                    "a" => Ok(ProbeStatus::Awaited(Probe::new(
895                        sequence,
896                        TraceId(0),
897                        src_port,
898                        dest_port,
899                        ttl,
900                        round,
901                        sent,
902                        flags,
903                    ))),
904                    "c" => {
905                        let host = IpAddr::from_str(values[3])?;
906                        let duration = Duration::from_millis(u64::from_str(values[2])?);
907                        let received = sent.add(duration);
908                        let expected_udp_checksum = Some(Checksum(u16::from_str(values[7])?));
909                        let actual_udp_checksum = Some(Checksum(u16::from_str(values[8])?));
910                        let icmp_packet_type = IcmpPacketType::NotApplicable;
911                        let tos = Some(TypeOfService(u8::from_str(values[9])?));
912                        Ok(ProbeStatus::Complete(
913                            Probe::new(
914                                sequence,
915                                TraceId(0),
916                                src_port,
917                                dest_port,
918                                ttl,
919                                round,
920                                sent,
921                                flags,
922                            )
923                            .complete(
924                                host,
925                                received,
926                                icmp_packet_type,
927                                tos,
928                                expected_udp_checksum,
929                                actual_udp_checksum,
930                                None,
931                            ),
932                        ))
933                    }
934                    _ => Err(anyhow!("unknown probe state")),
935                }?;
936                Ok(Self(state))
937            } else {
938                Err(anyhow!("failed to parse {}", value))
939            }
940        }
941    }
942
943    /// A helper struct so we may inject the round into the probes.
944    struct ProbeRound(ProbeData, RoundId);
945
946    impl From<ProbeRound> for ProbeStatus {
947        fn from(value: ProbeRound) -> Self {
948            let probe_data = value.0;
949            let round = value.1;
950            match probe_data.0 {
951                Self::NotSent => Self::NotSent,
952                Self::Skipped => Self::Skipped,
953                Self::Awaited(awaited) => Self::Awaited(Probe { round, ..awaited }),
954                Self::Complete(completed) => Self::Complete(ProbeComplete { round, ..completed }),
955                Self::Failed(failed) => Self::Failed(failed),
956            }
957        }
958    }
959
960    /// The expected outcome.
961    #[derive(Deserialize, Debug, Clone)]
962    #[serde(deny_unknown_fields)]
963    struct Expected {
964        /// The expected outcome per hop.
965        hops: Vec<HopData>,
966    }
967
968    /// The expected outcome for a single hop.
969    #[derive(Deserialize, Debug, Clone)]
970    #[serde(deny_unknown_fields)]
971    struct HopData {
972        ttl: Option<u8>,
973        total_sent: Option<usize>,
974        total_recv: Option<usize>,
975        total_forward_loss: Option<usize>,
976        total_backward_loss: Option<usize>,
977        loss_pct: Option<f64>,
978        last_ms: Option<f64>,
979        best_ms: Option<f64>,
980        worst_ms: Option<f64>,
981        avg_ms: Option<f64>,
982        jitter: Option<f64>,
983        javg: Option<f64>,
984        jmax: Option<f64>,
985        jinta: Option<f64>,
986        addrs: Option<HashMap<IpAddr, usize>>,
987        samples: Option<Vec<f64>>,
988        last_src: Option<u16>,
989        last_dest: Option<u16>,
990        last_sequence: Option<u16>,
991        last_nat_status: Option<NatStatusWrapper>,
992        tos: Option<u8>,
993    }
994
995    /// A wrapper struct over `NatStatus` to allow deserialization.
996    #[derive(Deserialize, Debug, Clone)]
997    #[serde(try_from = "String")]
998    struct NatStatusWrapper(NatStatus);
999
1000    impl TryFrom<String> for NatStatusWrapper {
1001        type Error = anyhow::Error;
1002
1003        fn try_from(value: String) -> Result<Self, Self::Error> {
1004            match value.to_ascii_lowercase().as_str() {
1005                "none" => Ok(Self(NatStatus::NotApplicable)),
1006                "nat" => Ok(Self(NatStatus::Detected)),
1007                "no_nat" => Ok(Self(NatStatus::NotDetected)),
1008                _ => Err(anyhow!("unknown nat status")),
1009            }
1010        }
1011    }
1012
1013    macro_rules! file {
1014        ($path:expr) => {{
1015            let data = include_str!(concat!("../tests/resources/state/", $path));
1016            toml::from_str(data).unwrap()
1017        }};
1018    }
1019
1020    #[test_case(file!("full_mixed.toml"))]
1021    #[test_case(file!("full_completed.toml"))]
1022    #[test_case(file!("all_status.toml"))]
1023    #[test_case(file!("no_latency.toml"))]
1024    #[test_case(file!("nat.toml"))]
1025    #[test_case(file!("minimal.toml"))]
1026    #[test_case(file!("floss_bloss.toml"))]
1027    #[test_case(file!("non_default_minimum_ttl.toml"))]
1028    #[test_case(file!("tos.toml"))]
1029    fn test_scenario(scenario: Scenario) {
1030        let mut trace = State::new(StateConfig {
1031            max_flows: 1,
1032            ..StateConfig::default()
1033        });
1034        for (i, round) in scenario.rounds.into_iter().enumerate() {
1035            let probes = round
1036                .probes
1037                .into_iter()
1038                .map(|p| ProbeRound(p, RoundId(i)))
1039                .map(Into::into)
1040                .collect::<Vec<_>>();
1041            let largest_ttl = TimeToLive(scenario.largest_ttl);
1042            let tracer_round = Round::new(&probes, largest_ttl, CompletionReason::TargetFound);
1043            trace.update_from_round(&tracer_round);
1044        }
1045        let actual_hops = trace.hops();
1046        let expected_hops = scenario.expected.hops;
1047        for (actual, expected) in actual_hops.iter().zip(expected_hops) {
1048            assert_eq_opt(Some(&actual.ttl()), expected.ttl.as_ref());
1049            assert_eq_opt(
1050                Some(actual.addrs().collect::<HashSet<_>>()),
1051                expected
1052                    .addrs
1053                    .as_ref()
1054                    .map(|addrs| addrs.keys().collect::<HashSet<_>>()),
1055            );
1056            assert_eq_opt(
1057                Some(actual.addr_count()),
1058                expected.addrs.as_ref().map(HashMap::len),
1059            );
1060            assert_eq_opt(Some(&actual.total_sent()), expected.total_sent.as_ref());
1061            assert_eq_opt(Some(&actual.total_recv()), expected.total_recv.as_ref());
1062            assert_eq_opt_f64(Some(&actual.loss_pct()), expected.loss_pct.as_ref());
1063            assert_eq_opt(
1064                Some(&actual.total_forward_loss()),
1065                expected.total_forward_loss.as_ref(),
1066            );
1067            assert_eq_opt(
1068                Some(&actual.total_backward_loss()),
1069                expected.total_backward_loss.as_ref(),
1070            );
1071            assert_eq_opt_f64(actual.last_ms().as_ref(), expected.last_ms.as_ref());
1072            assert_eq_opt_f64(actual.best_ms().as_ref(), expected.best_ms.as_ref());
1073            assert_eq_opt_f64(actual.worst_ms().as_ref(), expected.worst_ms.as_ref());
1074            assert_eq_opt_f64(Some(&actual.avg_ms()), expected.avg_ms.as_ref());
1075            assert_eq_opt_f64(actual.jitter_ms().as_ref(), expected.jitter.as_ref());
1076            assert_eq_opt_f64(Some(&actual.javg_ms()), expected.javg.as_ref());
1077            assert_eq_opt_f64(actual.jmax_ms().as_ref(), expected.jmax.as_ref());
1078            assert_eq_opt_f64(Some(&actual.jinta()), expected.jinta.as_ref());
1079            assert_eq_opt(Some(&actual.last_src_port()), expected.last_src.as_ref());
1080            assert_eq_opt(Some(&actual.last_dest_port()), expected.last_dest.as_ref());
1081            assert_eq_opt(
1082                Some(&actual.last_sequence()),
1083                expected.last_sequence.as_ref(),
1084            );
1085            assert_eq_opt(
1086                Some(&actual.last_nat_status()),
1087                expected.last_nat_status.as_ref().map(|nat| &nat.0),
1088            );
1089            assert_eq_vec_f64(
1090                Some(
1091                    &actual
1092                        .samples()
1093                        .iter()
1094                        .map(|s| s.as_secs_f64() * 1000_f64)
1095                        .collect(),
1096                ),
1097                expected.samples.as_ref(),
1098            );
1099            assert_eq_opt(actual.tos().map(|tos| tos.0), expected.tos);
1100        }
1101    }
1102
1103    #[allow(clippy::needless_pass_by_value)]
1104    fn assert_eq_opt<T: Eq + Debug>(actual: Option<T>, expected: Option<T>) {
1105        assert_eq_inner(actual.as_ref(), expected.as_ref(), |a, e| a == e);
1106    }
1107
1108    fn assert_eq_opt_f64(actual: Option<&f64>, expected: Option<&f64>) {
1109        assert_eq_inner(actual, expected, |a, e| (e - a).abs() < f64::EPSILON);
1110    }
1111
1112    fn assert_eq_vec_f64(actual: Option<&Vec<f64>>, expected: Option<&Vec<f64>>) {
1113        assert_eq_inner(actual, expected, |a, e| {
1114            if a.len() != e.len() {
1115                return false;
1116            }
1117            a.iter()
1118                .zip(e.iter())
1119                .all(|(a, e)| (e - a).abs() < f64::EPSILON)
1120        });
1121    }
1122
1123    fn assert_eq_inner<T: Debug>(
1124        actual: Option<&T>,
1125        expected: Option<&T>,
1126        eq: impl Fn(&T, &T) -> bool,
1127    ) {
1128        match (actual, expected) {
1129            (Some(actual), Some(expected)) if eq(actual, expected) => {}
1130            (Some(actual), Some(expected)) => {
1131                panic!("expected {expected:?} did not match actual {actual:?}")
1132            }
1133            (None, Some(_)) => panic!("expected {expected:?} but no actual"),
1134            (_, None) => {}
1135        }
1136    }
1137}