use self::state::TracerState;
use crate::config::StrategyConfig;
use crate::error::{Error, Result};
use crate::net::Network;
use crate::probe::{
ProbeStatus, Response, ResponseData, ResponseSeq, ResponseSeqIcmp, ResponseSeqTcp,
ResponseSeqUdp,
};
use crate::types::{Checksum, Sequence, TimeToLive, TraceId};
use crate::{Extensions, IcmpPacketType, MultipathStrategy, PortDirection, Probe, Protocol};
use std::net::IpAddr;
use std::time::{Duration, SystemTime};
use tracing::instrument;
/// The data for a single tracing round, as passed to the publish callback.
#[derive(Debug, Clone)]
pub struct Round<'a> {
/// The statuses of the probes recorded for the round.
pub probes: &'a [ProbeStatus],
/// The largest time-to-live reported for the round.
pub largest_ttl: TimeToLive,
/// Why the round was completed.
pub reason: CompletionReason,
}
impl<'a> Round<'a> {
#[must_use]
pub const fn new(
probes: &'a [ProbeStatus],
largest_ttl: TimeToLive,
reason: CompletionReason,
) -> Self {
Self {
probes,
largest_ttl,
reason,
}
}
}
/// The reason reported for a round when it is published.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompletionReason {
/// The target had been found when the round was published.
TargetFound,
/// The target had not been found when the round was published.
RoundTimeLimitExceeded,
}
/// Drives a trace: sends probes, receives responses and publishes each round.
#[derive(Debug, Clone)]
pub struct Strategy<F> {
/// The configuration the strategy was created with (stored by value).
config: StrategyConfig,
/// Callback invoked with every published [`Round`].
publish: F,
}
impl<F: Fn(&Round<'_>)> Strategy<F> {
/// Creates a strategy holding a copy of `config` and the `publish` callback.
///
/// The configuration is logged at debug level.
#[instrument(skip_all)]
pub fn new(config: &StrategyConfig, publish: F) -> Self {
    tracing::debug!(?config);
    let config = *config;
    Self { config, publish }
}
/// Runs the trace until the state reports that the configured number of rounds is done.
///
/// Each iteration attempts to send a probe, attempts to receive a response and then
/// checks whether the round should be published and advanced.
///
/// # Errors
///
/// Returns the first error raised while sending or receiving.
#[instrument(skip(self, network))]
pub fn run<N: Network>(self, mut network: N) -> Result<()> {
    let mut state = TracerState::new(self.config);
    loop {
        if state.finished(self.config.max_rounds) {
            break Ok(());
        }
        self.send_request(&mut network, &mut state)?;
        self.recv_response(&mut network, &mut state)?;
        self.update_round(&mut state);
    }
}
#[instrument(skip(self, network, st))]
fn send_request<N: Network>(&self, network: &mut N, st: &mut TracerState) -> Result<()> {
let can_send_ttl = if let Some(target_ttl) = st.target_ttl() {
st.ttl() <= target_ttl
} else {
st.ttl() - st.max_received_ttl().unwrap_or_default()
< TimeToLive(self.config.max_inflight.0)
};
if !st.target_found() && st.ttl() <= self.config.max_ttl && can_send_ttl {
let sent = SystemTime::now();
match self.config.protocol {
Protocol::Icmp => {
let probe = st.next_probe(sent);
Self::do_send(network, st, probe)?;
}
Protocol::Udp => {
let probe = st.next_probe(sent);
Self::do_send(network, st, probe)?;
}
Protocol::Tcp => {
let mut probe = if st.round_has_capacity() {
st.next_probe(sent)
} else {
return Err(Error::InsufficientCapacity);
};
while let Err(err) = Self::do_send(network, st, probe) {
match err {
Error::AddressInUse(_) => {
if st.round_has_capacity() {
probe = st.reissue_probe(SystemTime::now());
} else {
return Err(Error::InsufficientCapacity);
}
}
other => return Err(other),
}
}
}
};
}
Ok(())
}
/// Sends `probe` over `network`, recording a probe failure instead of returning it.
///
/// An `Error::ProbeFailed` result marks the most recently issued probe as failed in `st`
/// and is reported as success; every other result is returned unchanged.
fn do_send<N: Network>(network: &mut N, st: &mut TracerState, probe: Probe) -> Result<()> {
    match network.send_probe(probe) {
        Err(Error::ProbeFailed(_)) => {
            st.fail_probe();
            Ok(())
        }
        outcome => outcome,
    }
}
/// Receives at most one response and, if it belongs to this trace, completes its probe.
///
/// A response is applied only if it passes `validate`, its trace identifier is accepted
/// by `check_trace_id` and its sequence lies within the current round.
///
/// # Errors
///
/// Returns any error raised by the network while receiving.
#[instrument(skip(self, network, st))]
fn recv_response<N: Network>(&self, network: &mut N, st: &mut TracerState) -> Result<()> {
    let Some(raw) = network.recv_probe()? else {
        return Ok(());
    };
    if !self.validate(raw.data()) {
        return Ok(());
    }
    let resp = StrategyResponse::from((raw, &self.config));
    if self.check_trace_id(resp.trace_id) && st.in_round(resp.sequence) {
        st.complete_probe(resp);
    }
    Ok(())
}
/// Publishes the current round and starts the next one when the round is over.
///
/// The round is over when it has lasted longer than `max_round_duration`, or when it has
/// lasted longer than `min_round_duration`, the target has been found and the grace check
/// (`exceeds` applied to the last receive time, the current time and `grace_duration`)
/// passes.
#[instrument(skip(self, st))]
fn update_round(&self, st: &mut TracerState) {
    let now = SystemTime::now();
    let elapsed = now.duration_since(st.round_start()).unwrap_or_default();
    let past_min = elapsed > self.config.min_round_duration;
    let past_max = elapsed > self.config.max_round_duration;
    let grace_exceeded = exceeds(st.received_time(), now, self.config.grace_duration);
    let round_over = past_max || (past_min && grace_exceeded && st.target_found());
    if round_over {
        self.publish_trace(st);
        st.advance_round(self.config.first_ttl);
    }
}
/// Publishes the current round through the `publish` callback.
///
/// The reported largest time-to-live is the target time-to-live when known. Otherwise it
/// is one more than the largest time-to-live received, capped at the largest
/// time-to-live sent, or zero when nothing was received.
#[instrument(skip(self, state))]
fn publish_trace(&self, state: &TracerState) {
    let largest_ttl = match (state.target_ttl(), state.max_received_ttl()) {
        (Some(target_ttl), _) => target_ttl,
        (None, Some(max_received_ttl)) => {
            // `state.ttl()` is the next ttl to send, so the last one sent is one less.
            let max_sent_ttl = state.ttl() - TimeToLive(1);
            max_sent_ttl.min(max_received_ttl + TimeToLive(1))
        }
        (None, None) => TimeToLive(0),
    };
    let reason = if state.target_found() {
        CompletionReason::TargetFound
    } else {
        CompletionReason::RoundTimeLimitExceeded
    };
    let round = Round::new(state.probes(), largest_ttl, reason);
    (self.publish)(&round);
}
/// Returns true if `trace_id` is zero or equals the configured trace identifier.
#[instrument(skip(self))]
fn check_trace_id(&self, trace_id: TraceId) -> bool {
    let expected = self.config.trace_identifier;
    trace_id == TraceId(0) || trace_id == expected
}
/// Checks whether a response is consistent with the probes this strategy sends.
///
/// ICMP responses are always accepted. UDP and TCP responses must carry the configured
/// target as destination address and match the configured fixed port(s). UDP responses
/// must additionally have the magic flag set when using the Dublin strategy with an IPv6
/// target.
fn validate(&self, resp: &ResponseData) -> bool {
    /// Returns true if the port(s) fixed by `port_direction` equal the given ports.
    const fn validate_ports(
        port_direction: PortDirection,
        src_port: u16,
        dest_port: u16,
    ) -> bool {
        match port_direction {
            PortDirection::FixedSrc(src) => src.0 == src_port,
            PortDirection::FixedDest(dest) => dest.0 == dest_port,
            PortDirection::FixedBoth(src, dest) => src.0 == src_port && dest.0 == dest_port,
            PortDirection::None => false,
        }
    }
    let config = &self.config;
    match resp.resp_seq {
        ResponseSeq::Icmp(_) => true,
        ResponseSeq::Udp(ResponseSeqUdp {
            dest_addr,
            src_port,
            dest_port,
            has_magic,
            ..
        }) => {
            let magic_required = matches!(
                (config.multipath_strategy, config.target_addr),
                (MultipathStrategy::Dublin, IpAddr::V6(_))
            );
            let magic_ok = !magic_required || has_magic;
            let addr_ok = config.target_addr == dest_addr;
            let ports_ok = validate_ports(config.port_direction, src_port, dest_port);
            addr_ok && ports_ok && magic_ok
        }
        ResponseSeq::Tcp(ResponseSeqTcp {
            dest_addr,
            src_port,
            dest_port,
        }) => {
            let addr_ok = config.target_addr == dest_addr;
            let ports_ok = validate_ports(config.port_direction, src_port, dest_port);
            addr_ok && ports_ok
        }
    }
}
}
/// A network response reduced to the fields the tracer state needs to complete a probe.
#[derive(Debug)]
struct StrategyResponse {
/// The ICMP packet type of the response (`NotApplicable` for TCP replies).
icmp_packet_type: IcmpPacketType,
/// The trace identifier extracted from the response.
trace_id: TraceId,
/// The sequence number extracted from the response.
sequence: Sequence,
/// The expected UDP checksum, when one was extracted.
expected_udp_checksum: Option<Checksum>,
/// The actual UDP checksum, when one was extracted.
actual_udp_checksum: Option<Checksum>,
/// When the response was received.
received: SystemTime,
/// The address the response came from.
addr: IpAddr,
/// Whether the response is treated as coming from the target.
is_target: bool,
/// Extensions carried by the response, if any.
exts: Option<Extensions>,
}
impl From<(Response, &StrategyConfig)> for StrategyResponse {
    /// Converts a raw [`Response`] into a [`StrategyResponse`].
    ///
    /// `TimeExceeded` and `DestinationUnreachable` responses are flagged as coming from
    /// the target only when the responding address equals the configured target address.
    /// Echo and TCP replies are always flagged as the target and carry no extensions.
    fn from((resp, config): (Response, &StrategyConfig)) -> Self {
        // Reduce every variant to the parts that differ, then build the value once.
        let (data, icmp_packet_type, is_target, exts) = match resp {
            Response::TimeExceeded(data, code, exts) => {
                let is_target = data.addr == config.target_addr;
                (data, IcmpPacketType::TimeExceeded(code), is_target, exts)
            }
            Response::DestinationUnreachable(data, code, exts) => {
                let is_target = data.addr == config.target_addr;
                (data, IcmpPacketType::Unreachable(code), is_target, exts)
            }
            Response::EchoReply(data, code) => {
                (data, IcmpPacketType::EchoReply(code), true, None)
            }
            Response::TcpReply(data) | Response::TcpRefused(data) => {
                (data, IcmpPacketType::NotApplicable, true, None)
            }
        };
        let resp_seq = StrategyResponseSeq::from((data.resp_seq, config));
        Self {
            icmp_packet_type,
            trace_id: resp_seq.trace_id,
            sequence: resp_seq.sequence,
            expected_udp_checksum: resp_seq.expected_udp_checksum,
            actual_udp_checksum: resp_seq.actual_udp_checksum,
            received: data.recv,
            addr: data.addr,
            is_target,
            exts,
        }
    }
}
/// The trace identifier, sequence and UDP checksums extracted from a response sequence.
#[derive(Debug)]
struct StrategyResponseSeq {
/// The trace identifier (zero for UDP and TCP).
trace_id: TraceId,
/// The sequence number recovered from the response.
sequence: Sequence,
/// The expected UDP checksum (only set for UDP with the Dublin strategy over IPv4).
expected_udp_checksum: Option<Checksum>,
/// The actual UDP checksum (only set for UDP with the Dublin strategy over IPv4).
actual_udp_checksum: Option<Checksum>,
}
impl From<(ResponseSeq, &StrategyConfig)> for StrategyResponseSeq {
    /// Recovers the trace identifier and probe sequence from a response.
    ///
    /// Where the sequence is found depends on the protocol and configuration:
    /// - ICMP: the identifier and sequence are taken directly from the response.
    /// - UDP classic: the source port when the destination port is fixed, otherwise the
    ///   destination port.
    /// - UDP Paris: the actual UDP checksum.
    /// - UDP Dublin over IPv4: the identifier.
    /// - UDP Dublin over IPv6: the initial sequence plus the payload length.
    /// - TCP: the destination port when the source port is fixed, otherwise the source
    ///   port.
    ///
    /// The UDP checksums are only reported for Dublin over IPv4.
    fn from((resp_seq, config): (ResponseSeq, &StrategyConfig)) -> Self {
        match resp_seq {
            ResponseSeq::Icmp(ResponseSeqIcmp {
                identifier,
                sequence,
            }) => Self {
                trace_id: TraceId(identifier),
                sequence: Sequence(sequence),
                expected_udp_checksum: None,
                actual_udp_checksum: None,
            },
            ResponseSeq::Udp(ResponseSeqUdp {
                identifier,
                src_port,
                dest_port,
                expected_udp_checksum,
                actual_udp_checksum,
                payload_len,
                ..
            }) => {
                let sequence = match (
                    config.multipath_strategy,
                    config.port_direction,
                    config.target_addr,
                ) {
                    (MultipathStrategy::Classic, PortDirection::FixedDest(_), _) => src_port,
                    (MultipathStrategy::Classic, _, _) => dest_port,
                    (MultipathStrategy::Paris, _, _) => actual_udp_checksum,
                    (MultipathStrategy::Dublin, _, IpAddr::V4(_)) => identifier,
                    (MultipathStrategy::Dublin, _, IpAddr::V6(_)) => {
                        // `payload_len` comes from the received packet, so a plain `+`
                        // could overflow and panic in overflow-checked builds. Wrap
                        // instead; a wrapped value is smaller than the initial sequence.
                        config.initial_sequence.0.wrapping_add(payload_len)
                    }
                };
                let (expected_udp_checksum, actual_udp_checksum) =
                    match (config.multipath_strategy, config.target_addr) {
                        (MultipathStrategy::Dublin, IpAddr::V4(_)) => (
                            Some(Checksum(expected_udp_checksum)),
                            Some(Checksum(actual_udp_checksum)),
                        ),
                        _ => (None, None),
                    };
                Self {
                    trace_id: TraceId(0),
                    sequence: Sequence(sequence),
                    expected_udp_checksum,
                    actual_udp_checksum,
                }
            }
            ResponseSeq::Tcp(ResponseSeqTcp {
                src_port,
                dest_port,
                ..
            }) => {
                let sequence = match config.port_direction {
                    PortDirection::FixedSrc(_) => dest_port,
                    _ => src_port,
                };
                Self {
                    trace_id: TraceId(0),
                    sequence: Sequence(sequence),
                    expected_udp_checksum: None,
                    actual_udp_checksum: None,
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::net::MockNetwork;
use crate::probe::IcmpPacketCode;
use crate::{MaxRounds, Port};
use std::net::Ipv4Addr;
use std::num::NonZeroUsize;
/// A `TimeExceeded` response from the target address is flagged as the target.
#[test]
fn test_time_exceeded_target_response() {
    let target = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
    let config = StrategyConfig {
        target_addr: target,
        ..Default::default()
    };
    let received = SystemTime::now();
    let response = Response::TimeExceeded(response_data(received), IcmpPacketCode(1), None);
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(
        converted.icmp_packet_type,
        IcmpPacketType::TimeExceeded(IcmpPacketCode(1))
    );
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, target);
    assert!(converted.is_target);
    assert!(converted.exts.is_none());
}
/// A `TimeExceeded` response from a non-target address is not flagged as the target.
#[test]
fn test_time_exceeded_not_target_response() {
    let config = StrategyConfig {
        target_addr: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        ..Default::default()
    };
    let received = SystemTime::now();
    let response = Response::TimeExceeded(response_data(received), IcmpPacketCode(1), None);
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(
        converted.icmp_packet_type,
        IcmpPacketType::TimeExceeded(IcmpPacketCode(1))
    );
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)));
    assert!(!converted.is_target);
    assert!(converted.exts.is_none());
}
/// A `DestinationUnreachable` response from the target address is flagged as the target.
#[test]
fn test_destination_unreachable_target_response() {
    let target = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
    let config = StrategyConfig {
        target_addr: target,
        ..Default::default()
    };
    let received = SystemTime::now();
    let response =
        Response::DestinationUnreachable(response_data(received), IcmpPacketCode(10), None);
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(
        converted.icmp_packet_type,
        IcmpPacketType::Unreachable(IcmpPacketCode(10))
    );
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, target);
    assert!(converted.is_target);
    assert!(converted.exts.is_none());
}
/// With the default configuration a `DestinationUnreachable` response from 1.2.3.4 is not
/// flagged as the target.
#[test]
fn test_destination_unreachable_not_target_response() {
    let config = StrategyConfig::default();
    let received = SystemTime::now();
    let response =
        Response::DestinationUnreachable(response_data(received), IcmpPacketCode(10), None);
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(
        converted.icmp_packet_type,
        IcmpPacketType::Unreachable(IcmpPacketCode(10))
    );
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)));
    assert!(!converted.is_target);
    assert!(converted.exts.is_none());
}
/// An `EchoReply` response is always flagged as the target.
#[test]
fn test_echo_reply_response() {
    let config = StrategyConfig::default();
    let received = SystemTime::now();
    let response = Response::EchoReply(response_data(received), IcmpPacketCode(99));
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(
        converted.icmp_packet_type,
        IcmpPacketType::EchoReply(IcmpPacketCode(99))
    );
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)));
    assert!(converted.is_target);
    assert!(converted.exts.is_none());
}
/// A `TcpReply` response is flagged as the target and has no ICMP packet type.
#[test]
fn test_tcp_reply_response() {
    let config = StrategyConfig::default();
    let received = SystemTime::now();
    let response = Response::TcpReply(response_data(received));
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(converted.icmp_packet_type, IcmpPacketType::NotApplicable);
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)));
    assert!(converted.is_target);
    assert!(converted.exts.is_none());
}
/// A `TcpRefused` response is flagged as the target and has no ICMP packet type.
#[test]
fn test_tcp_refused_response() {
    let config = StrategyConfig::default();
    let received = SystemTime::now();
    let response = Response::TcpRefused(response_data(received));
    let converted = StrategyResponse::from((response, &config));
    assert_eq!(converted.icmp_packet_type, IcmpPacketType::NotApplicable);
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
    assert_eq!(converted.received, received);
    assert_eq!(converted.addr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)));
    assert!(converted.is_target);
    assert!(converted.exts.is_none());
}
/// An ICMP response sequence yields its identifier and sequence unchanged.
#[test]
fn test_icmp_response() {
    let config = StrategyConfig::default();
    let icmp = ResponseSeqIcmp {
        identifier: 1234,
        sequence: 33434,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Icmp(icmp), &config));
    assert_eq!(converted.trace_id, TraceId(1234));
    assert_eq!(converted.sequence, Sequence(33434));
}
/// UDP with the default strategy and a fixed source port: the sequence is the
/// destination port.
#[test]
fn test_udp_classic_fixed_src_response() {
    let config = StrategyConfig {
        port_direction: PortDirection::FixedSrc(Port(5000)),
        protocol: Protocol::Udp,
        ..Default::default()
    };
    let udp = ResponseSeqUdp {
        identifier: 0,
        dest_addr: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        src_port: 5000,
        dest_port: 33434,
        expected_udp_checksum: 0,
        actual_udp_checksum: 0,
        payload_len: 0,
        has_magic: false,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Udp(udp), &config));
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
}
/// UDP with the default strategy and a fixed destination port: the sequence is the
/// source port.
#[test]
fn test_udp_classic_fixed_dest_response() {
    let config = StrategyConfig {
        port_direction: PortDirection::FixedDest(Port(5000)),
        protocol: Protocol::Udp,
        ..Default::default()
    };
    let udp = ResponseSeqUdp {
        identifier: 0,
        dest_addr: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        src_port: 33434,
        dest_port: 5000,
        expected_udp_checksum: 0,
        actual_udp_checksum: 0,
        payload_len: 0,
        has_magic: false,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Udp(udp), &config));
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
}
/// UDP with the Paris strategy: the sequence is the actual UDP checksum.
#[test]
fn test_udp_paris_response() {
    let config = StrategyConfig {
        port_direction: PortDirection::FixedSrc(Port(5000)),
        multipath_strategy: MultipathStrategy::Paris,
        protocol: Protocol::Udp,
        ..Default::default()
    };
    let udp = ResponseSeqUdp {
        identifier: 33434,
        dest_addr: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        src_port: 5000,
        dest_port: 35000,
        expected_udp_checksum: 33434,
        actual_udp_checksum: 33434,
        payload_len: 0,
        has_magic: false,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Udp(udp), &config));
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
}
/// UDP with the Dublin strategy and the default target: the sequence is the identifier.
#[test]
fn test_udp_dublin_ipv4_response() {
    let config = StrategyConfig {
        port_direction: PortDirection::FixedSrc(Port(5000)),
        multipath_strategy: MultipathStrategy::Dublin,
        protocol: Protocol::Udp,
        ..Default::default()
    };
    let udp = ResponseSeqUdp {
        identifier: 33434,
        dest_addr: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        src_port: 5000,
        dest_port: 35000,
        expected_udp_checksum: 0,
        actual_udp_checksum: 0,
        payload_len: 0,
        has_magic: false,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Udp(udp), &config));
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33434));
}
/// UDP with the Dublin strategy over IPv6: the sequence is the initial sequence plus the
/// payload length.
#[test]
fn test_udp_dublin_ipv6_response() {
    let loopback = IpAddr::V6("::1".parse().unwrap());
    let config = StrategyConfig {
        target_addr: loopback,
        port_direction: PortDirection::FixedSrc(Port(5000)),
        multipath_strategy: MultipathStrategy::Dublin,
        protocol: Protocol::Udp,
        ..Default::default()
    };
    let udp = ResponseSeqUdp {
        identifier: 0,
        dest_addr: loopback,
        src_port: 5000,
        dest_port: 35000,
        expected_udp_checksum: 0,
        actual_udp_checksum: 0,
        payload_len: 55,
        has_magic: true,
    };
    let converted = StrategyResponseSeq::from((ResponseSeq::Udp(udp), &config));
    assert_eq!(converted.trace_id, TraceId(0));
    assert_eq!(converted.sequence, Sequence(33489));
}
/// TCP with a fixed destination port: the sequence is the source port.
#[test]
fn test_tcp_fixed_dest_response() {
    let config = StrategyConfig {
        protocol: Protocol::Tcp,
        port_direction: PortDirection::FixedDest(Port(80)),
        ..Default::default()
    };
    // Use a TCP response sequence so that the TCP conversion branch is the one exercised
    // (arguments: destination address, source port, destination port).
    let resp_seq = ResponseSeq::Tcp(ResponseSeqTcp::new(
        IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        33434,
        80,
    ));
    let strategy_resp = StrategyResponseSeq::from((resp_seq, &config));
    assert_eq!(strategy_resp.trace_id, TraceId(0));
    assert_eq!(strategy_resp.sequence, Sequence(33434));
}
/// TCP with a fixed source port: the sequence is the destination port.
#[test]
fn test_tcp_fixed_src_response() {
    let config = StrategyConfig {
        protocol: Protocol::Tcp,
        port_direction: PortDirection::FixedSrc(Port(5000)),
        ..Default::default()
    };
    // Use a TCP response sequence so that the TCP conversion branch is the one exercised
    // (arguments: destination address, source port, destination port).
    let resp_seq = ResponseSeq::Tcp(ResponseSeqTcp::new(
        IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
        5000,
        33434,
    ));
    let strategy_resp = StrategyResponseSeq::from((resp_seq, &config));
    assert_eq!(strategy_resp.trace_id, TraceId(0));
    assert_eq!(strategy_resp.sequence, Sequence(33434));
}
/// Two responses for the same TCP probe must both be handled without error.
///
/// The mocked network first yields a `DestinationUnreachable` and then a `TcpRefused`,
/// both carrying the sequence of the single probe that was sent.
#[test]
fn test_tcp_dest_unreachable_and_refused() -> anyhow::Result<()> {
let sequence = 33434;
let target_addr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
let mut network = MockNetwork::new();
let mut seq = mockall::Sequence::new();
// Exactly one probe is expected to be sent.
network.expect_send_probe().times(1).returning(|_| Ok(()));
// First receive: a `DestinationUnreachable` from the target for the probe.
network
.expect_recv_probe()
.times(1)
.in_sequence(&mut seq)
.returning(move || {
Ok(Some(Response::DestinationUnreachable(
ResponseData::new(
SystemTime::now(),
target_addr,
ResponseSeq::Tcp(ResponseSeqTcp::new(target_addr, sequence, 80)),
),
IcmpPacketCode(1),
None,
)))
});
// Second receive: a `TcpRefused` for the same probe.
network
.expect_recv_probe()
.times(1)
.in_sequence(&mut seq)
.returning(move || {
Ok(Some(Response::TcpRefused(ResponseData::new(
SystemTime::now(),
target_addr,
ResponseSeq::Tcp(ResponseSeqTcp::new(target_addr, sequence, 80)),
))))
});
let config = StrategyConfig {
target_addr,
max_rounds: Some(MaxRounds(NonZeroUsize::MIN)),
initial_sequence: Sequence(sequence),
port_direction: PortDirection::FixedDest(Port(80)),
protocol: Protocol::Tcp,
..Default::default()
};
let tracer = Strategy::new(&config, |_| {});
let mut state = TracerState::new(config);
// Send one probe, then process both responses.
tracer.send_request(&mut network, &mut state)?;
tracer.recv_response(&mut network, &mut state)?;
tracer.recv_response(&mut network, &mut state)?;
Ok(())
}
/// Builds the `ResponseData` shared by the response tests: received at `now` from
/// 1.2.3.4, carrying an ICMP response sequence with identifier 0 and sequence 33434.
const fn response_data(now: SystemTime) -> ResponseData {
ResponseData::new(
now,
IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
ResponseSeq::Icmp(ResponseSeqIcmp {
identifier: 0,
sequence: 33434,
}),
)
}
}
mod state {
use crate::constants::MAX_SEQUENCE_PER_ROUND;
use crate::probe::{Probe, ProbeStatus};
use crate::strategy::{StrategyConfig, StrategyResponse};
use crate::types::{MaxRounds, Port, RoundId, Sequence, TimeToLive, TraceId};
use crate::{Flags, MultipathStrategy, PortDirection, Protocol};
use std::array::from_fn;
use std::net::IpAddr;
use std::time::SystemTime;
use tracing::instrument;
/// The number of probe slots in the state buffer, which bounds the probes per round.
const BUFFER_SIZE: u16 = MAX_SEQUENCE_PER_ROUND;
/// The sequence at or beyond which the next round restarts from the initial sequence.
///
/// It leaves room for a full buffer of sequences below `u16::MAX`.
const MAX_SEQUENCE: Sequence = Sequence(u16::MAX - BUFFER_SIZE);
/// The mutable state of a trace.
#[derive(Debug)]
pub struct TracerState {
/// The strategy configuration.
config: StrategyConfig,
/// Probe statuses, indexed by `sequence - round_sequence`; slots are overwritten as
/// probes are issued.
buffer: [ProbeStatus; BUFFER_SIZE as usize],
/// The sequence the next probe will use.
sequence: Sequence,
/// The sequence at which the current round started.
round_sequence: Sequence,
/// The time-to-live the next probe will use.
ttl: TimeToLive,
/// The current round.
round: RoundId,
/// When the current round started.
round_start: SystemTime,
/// Whether a response flagged as coming from the target was received this round.
target_found: bool,
/// The largest time-to-live among the probes completed this round.
max_received_ttl: Option<TimeToLive>,
/// The time-to-live at which the target responded, if known; it is not reset when the
/// round advances.
target_ttl: Option<TimeToLive>,
/// The receive time of the most recently completed probe of this round.
received_time: Option<SystemTime>,
}
impl TracerState {
/// Creates the state for the first round.
///
/// The sequence starts at `config.initial_sequence`, the time-to-live at
/// `config.first_ttl`, and every buffer slot holds the default probe status.
pub fn new(config: StrategyConfig) -> Self {
    let initial_sequence = config.initial_sequence;
    let first_ttl = config.first_ttl;
    Self {
        config,
        buffer: from_fn(|_| ProbeStatus::default()),
        sequence: initial_sequence,
        round_sequence: initial_sequence,
        ttl: first_ttl,
        round: RoundId(0),
        round_start: SystemTime::now(),
        target_found: false,
        max_received_ttl: None,
        target_ttl: None,
        received_time: None,
    }
}
pub fn probes(&self) -> &[ProbeStatus] {
let round_size = self.sequence - self.round_sequence;
&self.buffer[..round_size.0 as usize]
}
/// Returns a copy of the probe status stored for `sequence` in the current round.
///
/// # Panics
///
/// Panics if the slot index falls outside the buffer.
pub fn probe_at(&self, sequence: Sequence) -> ProbeStatus {
    let slot = usize::from(sequence - self.round_sequence);
    self.buffer[slot].clone()
}
/// Returns the time-to-live the next probe will use.
pub const fn ttl(&self) -> TimeToLive {
self.ttl
}
/// Returns the time at which the current round started.
pub const fn round_start(&self) -> SystemTime {
self.round_start
}
/// Returns whether the target has been found in the current round.
pub const fn target_found(&self) -> bool {
self.target_found
}
/// Returns the largest time-to-live received in the current round, if any.
pub const fn max_received_ttl(&self) -> Option<TimeToLive> {
self.max_received_ttl
}
/// Returns the time-to-live at which the target responded, if known.
pub const fn target_ttl(&self) -> Option<TimeToLive> {
self.target_ttl
}
/// Returns when the most recent probe of the current round was completed, if any.
pub const fn received_time(&self) -> Option<SystemTime> {
self.received_time
}
/// Returns true if `sequence` is at or after the start of the current round and fewer
/// than `BUFFER_SIZE` sequences beyond it.
pub fn in_round(&self, sequence: Sequence) -> bool {
    sequence
        .0
        .checked_sub(self.round_sequence.0)
        .is_some_and(|offset| offset < BUFFER_SIZE)
}
pub fn round_has_capacity(&self) -> bool {
let round_size = self.sequence - self.round_sequence;
round_size.0 < BUFFER_SIZE
}
/// Returns true once `max_rounds` rounds have been completed.
///
/// Always returns false when no maximum is given.
pub const fn finished(&self, max_rounds: Option<MaxRounds>) -> bool {
    match max_rounds {
        // The limit is non-zero, so `round >= limit` equals `round > limit - 1`.
        Some(limit) => self.round.0 >= limit.0.get(),
        None => false,
    }
}
/// Issues the next probe of the round.
///
/// The probe is stored as `Awaited` in the slot for the current sequence, after which
/// both the time-to-live and the sequence are advanced by one.
#[instrument(skip(self))]
pub fn next_probe(&mut self, sent: SystemTime) -> Probe {
    let slot = usize::from(self.sequence - self.round_sequence);
    let (src_port, dest_port, identifier, flags) = self.probe_data();
    let probe = Probe::new(
        self.sequence,
        identifier,
        src_port,
        dest_port,
        self.ttl,
        self.round,
        sent,
        flags,
    );
    self.buffer[slot] = ProbeStatus::Awaited(probe.clone());
    debug_assert!(self.ttl < TimeToLive(u8::MAX));
    self.ttl += TimeToLive(1);
    debug_assert!(self.sequence < Sequence(u16::MAX));
    self.sequence += Sequence(1);
    probe
}
/// Replaces the most recently issued probe with a new one on the next sequence.
///
/// The previous slot is marked `Skipped`. The new probe reuses the time-to-live of the
/// probe it replaces (one less than the next time-to-live), is stored as `Awaited`, and
/// the sequence is advanced by one; the time-to-live is left untouched.
#[instrument(skip(self))]
pub fn reissue_probe(&mut self, sent: SystemTime) -> Probe {
    let slot = usize::from(self.sequence - self.round_sequence);
    self.buffer[slot - 1] = ProbeStatus::Skipped;
    let (src_port, dest_port, identifier, flags) = self.probe_data();
    let previous_ttl = self.ttl - TimeToLive(1);
    let probe = Probe::new(
        self.sequence,
        identifier,
        src_port,
        dest_port,
        previous_ttl,
        self.round,
        sent,
        flags,
    );
    self.buffer[slot] = ProbeStatus::Awaited(probe.clone());
    debug_assert!(self.sequence < Sequence(u16::MAX));
    self.sequence += Sequence(1);
    probe
}
/// Marks the most recently issued probe as failed.
///
/// # Panics
///
/// Panics if that probe is not in the `Awaited` state.
#[instrument(skip(self))]
pub fn fail_probe(&mut self) {
    let next_slot = usize::from(self.sequence - self.round_sequence);
    let last_slot = next_slot - 1;
    let ProbeStatus::Awaited(awaited) = self.buffer[last_slot].clone() else {
        unreachable!("expected ProbeStatus::Awaited")
    };
    self.buffer[last_slot] = ProbeStatus::Failed(awaited.failed());
}
/// Returns the source port, destination port, identifier and flags for the next probe,
/// as determined by the configured protocol.
fn probe_data(&self) -> (Port, Port, TraceId, Flags) {
match self.config.protocol {
Protocol::Icmp => self.probe_icmp_data(),
Protocol::Udp => self.probe_udp_data(),
Protocol::Tcp => self.probe_tcp_data(),
}
}
/// Returns the probe data for ICMP: zero ports, the configured trace identifier and no
/// flags.
const fn probe_icmp_data(&self) -> (Port, Port, TraceId, Flags) {
(
Port(0),
Port(0),
self.config.trace_identifier,
Flags::empty(),
)
}
fn probe_udp_data(&self) -> (Port, Port, TraceId, Flags) {
match self.config.multipath_strategy {
MultipathStrategy::Classic => match self.config.port_direction {
PortDirection::FixedSrc(src_port) => (
Port(src_port.0),
Port(self.sequence.0),
TraceId(0),
Flags::empty(),
),
PortDirection::FixedDest(dest_port) => (
Port(self.sequence.0),
Port(dest_port.0),
TraceId(0),
Flags::empty(),
),
PortDirection::FixedBoth(_, _) | PortDirection::None => {
unimplemented!()
}
},
MultipathStrategy::Paris => {
let round_port = ((self.config.initial_sequence.0 as usize + self.round.0)
% usize::from(u16::MAX)) as u16;
match self.config.port_direction {
PortDirection::FixedSrc(src_port) => (
Port(src_port.0),
Port(round_port),
TraceId(0),
Flags::PARIS_CHECKSUM,
),
PortDirection::FixedDest(dest_port) => (
Port(round_port),
Port(dest_port.0),
TraceId(0),
Flags::PARIS_CHECKSUM,
),
PortDirection::FixedBoth(src_port, dest_port) => (
Port(src_port.0),
Port(dest_port.0),
TraceId(0),
Flags::PARIS_CHECKSUM,
),
PortDirection::None => unimplemented!(),
}
}
MultipathStrategy::Dublin => {
let round_port = ((self.config.initial_sequence.0 as usize + self.round.0)
% usize::from(u16::MAX)) as u16;
match self.config.port_direction {
PortDirection::FixedSrc(src_port) => (
Port(src_port.0),
Port(round_port),
TraceId(self.sequence.0),
Flags::DUBLIN_IPV6_PAYLOAD_LENGTH,
),
PortDirection::FixedDest(dest_port) => (
Port(round_port),
Port(dest_port.0),
TraceId(self.sequence.0),
Flags::DUBLIN_IPV6_PAYLOAD_LENGTH,
),
PortDirection::FixedBoth(src_port, dest_port) => (
Port(src_port.0),
Port(dest_port.0),
TraceId(self.sequence.0),
Flags::DUBLIN_IPV6_PAYLOAD_LENGTH,
),
PortDirection::None => unimplemented!(),
}
}
}
}
/// Returns the probe data for TCP: the sequence is placed in whichever port is not fixed.
///
/// # Panics
///
/// Panics (`unimplemented!`) for `PortDirection::FixedBoth` and `PortDirection::None`.
fn probe_tcp_data(&self) -> (Port, Port, TraceId, Flags) {
    let sequence = self.sequence.0;
    let (src, dest) = match self.config.port_direction {
        PortDirection::FixedSrc(src_port) => (Port(src_port.0), Port(sequence)),
        PortDirection::FixedDest(dest_port) => (Port(sequence), Port(dest_port.0)),
        PortDirection::FixedBoth(_, _) | PortDirection::None => unimplemented!(),
    };
    (src, dest, TraceId(0), Flags::empty())
}
/// Marks the probe for `resp.sequence` as complete and updates the round bookkeeping.
///
/// A response for a probe that is already `Complete` is ignored. A response for a probe
/// in any other non-`Awaited` state trips a debug assertion and is otherwise ignored.
#[instrument(skip(self))]
pub fn complete_probe(&mut self, resp: StrategyResponse) {
    let probe = self.probe_at(resp.sequence);
    let awaited = match probe {
        ProbeStatus::Awaited(awaited) => awaited,
        ProbeStatus::Complete(_) => return,
        _ => {
            debug_assert!(
                false,
                "completed probe was not in Awaited state (probe={probe:#?})"
            );
            return;
        }
    };
    let slot = usize::from(resp.sequence - self.round_sequence);
    let completed = awaited.complete(
        resp.addr,
        resp.received,
        resp.icmp_packet_type,
        resp.expected_udp_checksum,
        resp.actual_udp_checksum,
        resp.exts,
    );
    let ttl = completed.ttl;
    self.buffer[slot] = ProbeStatus::Complete(completed);
    self.target_ttl = if resp.is_target {
        // Keep the smallest ttl at which the target has responded.
        Some(self.target_ttl.map_or(ttl, |target_ttl| target_ttl.min(ttl)))
    } else {
        // A non-target response at or beyond the known target ttl invalidates it.
        self.target_ttl.filter(|&target_ttl| ttl < target_ttl)
    };
    self.max_received_ttl = Some(self.max_received_ttl.map_or(ttl, |max_ttl| max_ttl.max(ttl)));
    self.received_time = Some(resp.received);
    self.target_found |= resp.is_target;
}
/// Starts the next round.
///
/// The sequence restarts from the initial sequence once it has reached `max_sequence`.
/// The per-round fields are reset, the round counter is incremented and the time-to-live
/// restarts at `first_ttl`. The target time-to-live is left as it is.
#[instrument(skip(self))]
pub fn advance_round(&mut self, first_ttl: TimeToLive) {
    let sequence_exhausted = self.sequence >= self.max_sequence();
    if sequence_exhausted {
        self.sequence = self.config.initial_sequence;
    }
    self.round_sequence = self.sequence;
    self.round += RoundId(1);
    self.round_start = SystemTime::now();
    self.ttl = first_ttl;
    self.target_found = false;
    self.max_received_ttl = None;
    self.received_time = None;
}
/// Returns the sequence at which the next round must restart from the initial sequence.
///
/// For the Dublin strategy with an IPv6 target this is the initial sequence plus
/// `BUFFER_SIZE`; otherwise it is `MAX_SEQUENCE`.
fn max_sequence(&self) -> Sequence {
    let dublin_ipv6 = matches!(
        (self.config.multipath_strategy, self.config.target_addr),
        (MultipathStrategy::Dublin, IpAddr::V6(_))
    );
    if dublin_ipv6 {
        self.config.initial_sequence + Sequence(BUFFER_SIZE)
    } else {
        MAX_SEQUENCE
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::probe::{IcmpPacketCode, IcmpPacketType};
use crate::types::MaxInflight;
use rand::Rng;
use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
/// Walks the state through two rounds, checking every field after each step.
#[allow(
clippy::cognitive_complexity,
clippy::too_many_lines,
clippy::bool_assert_comparison
)]
#[test]
fn test_state() {
// Freshly created state.
let mut state = TracerState::new(cfg(Sequence(33434)));
assert_eq!(state.round, RoundId(0));
assert_eq!(state.sequence, Sequence(33434));
assert_eq!(state.round_sequence, Sequence(33434));
assert_eq!(state.ttl, TimeToLive(1));
assert_eq!(state.max_received_ttl, None);
assert_eq!(state.received_time, None);
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// A slot that has not been used yet reports `NotSent`.
let prob_init = state.probe_at(Sequence(33434));
assert_eq!(ProbeStatus::NotSent, prob_init);
// Round 0: issue the first probe.
let sent_1 = SystemTime::now();
let probe_1 = state.next_probe(sent_1);
assert_eq!(probe_1.sequence, Sequence(33434));
assert_eq!(probe_1.ttl, TimeToLive(1));
assert_eq!(probe_1.round, RoundId(0));
assert_eq!(probe_1.sent, sent_1);
// Round 0: complete the first probe with a non-target response.
let received_1 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe(StrategyResponse {
icmp_packet_type: IcmpPacketType::TimeExceeded(IcmpPacketCode(1)),
trace_id: TraceId(0),
sequence: Sequence(33434),
expected_udp_checksum: None,
actual_udp_checksum: None,
received: received_1,
addr: host,
is_target: false,
exts: None,
});
let probe_1_fetch = state.probe_at(Sequence(33434)).try_into_complete().unwrap();
assert_eq!(probe_1_fetch.sequence, Sequence(33434));
assert_eq!(probe_1_fetch.ttl, TimeToLive(1));
assert_eq!(probe_1_fetch.round, RoundId(0));
assert_eq!(probe_1_fetch.received, received_1);
assert_eq!(probe_1_fetch.host, host);
assert_eq!(probe_1_fetch.sent, sent_1);
assert_eq!(
probe_1_fetch.icmp_packet_type,
IcmpPacketType::TimeExceeded(IcmpPacketCode(1))
);
assert_eq!(state.round, RoundId(0));
assert_eq!(state.sequence, Sequence(33435));
assert_eq!(state.round_sequence, Sequence(33434));
assert_eq!(state.ttl, TimeToLive(2));
assert_eq!(state.max_received_ttl, Some(TimeToLive(1)));
assert_eq!(state.received_time, Some(received_1));
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// The round holds exactly the one completed probe.
{
let mut probe_iter = state.probes().iter();
let probe_next1 = probe_iter.next().unwrap();
assert_eq!(ProbeStatus::Complete(probe_1_fetch), probe_next1.clone());
assert_eq!(None, probe_iter.next());
}
// Advance to round 1: per-round fields reset, the sequence carries on.
state.advance_round(TimeToLive(1));
assert_eq!(state.round, RoundId(1));
assert_eq!(state.sequence, Sequence(33435));
assert_eq!(state.round_sequence, Sequence(33435));
assert_eq!(state.ttl, TimeToLive(1));
assert_eq!(state.max_received_ttl, None);
assert_eq!(state.received_time, None);
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// Round 1: issue two probes.
let sent_2 = SystemTime::now();
let probe_2 = state.next_probe(sent_2);
assert_eq!(probe_2.sequence, Sequence(33435));
assert_eq!(probe_2.ttl, TimeToLive(1));
assert_eq!(probe_2.round, RoundId(1));
assert_eq!(probe_2.sent, sent_2);
let sent_3 = SystemTime::now();
let probe_3 = state.next_probe(sent_3);
assert_eq!(probe_3.sequence, Sequence(33436));
assert_eq!(probe_3.ttl, TimeToLive(2));
assert_eq!(probe_3.round, RoundId(1));
assert_eq!(probe_3.sent, sent_3);
// Round 1: complete the first of the two probes with a non-target response.
let received_2 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe(StrategyResponse {
icmp_packet_type: IcmpPacketType::TimeExceeded(IcmpPacketCode(1)),
trace_id: TraceId(0),
sequence: Sequence(33435),
expected_udp_checksum: None,
actual_udp_checksum: None,
received: received_2,
addr: host,
is_target: false,
exts: None,
});
let probe_2_recv = state.probe_at(Sequence(33435));
assert_eq!(state.round, RoundId(1));
assert_eq!(state.sequence, Sequence(33437));
assert_eq!(state.round_sequence, Sequence(33435));
assert_eq!(state.ttl, TimeToLive(3));
assert_eq!(state.max_received_ttl, Some(TimeToLive(1)));
assert_eq!(state.received_time, Some(received_2));
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// The second probe of the round is still awaited.
{
let mut probe_iter = state.probes().iter();
let probe_next1 = probe_iter.next().unwrap();
assert_eq!(&probe_2_recv, probe_next1);
let probe_next2 = probe_iter.next().unwrap();
assert_eq!(ProbeStatus::Awaited(probe_3), probe_next2.clone());
}
// Round 1: complete the second probe with a response from the target.
let received_3 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe(StrategyResponse {
icmp_packet_type: IcmpPacketType::EchoReply(IcmpPacketCode(0)),
trace_id: TraceId(0),
sequence: Sequence(33436),
expected_udp_checksum: None,
actual_udp_checksum: None,
received: received_3,
addr: host,
is_target: true,
exts: None,
});
let probe_3_recv = state.probe_at(Sequence(33436));
assert_eq!(state.round, RoundId(1));
assert_eq!(state.sequence, Sequence(33437));
assert_eq!(state.round_sequence, Sequence(33435));
assert_eq!(state.ttl, TimeToLive(3));
assert_eq!(state.max_received_ttl, Some(TimeToLive(2)));
assert_eq!(state.received_time, Some(received_3));
assert_eq!(state.target_ttl, Some(TimeToLive(2)));
assert_eq!(state.target_found, true);
// Both probes of the round are now stored as received.
{
let mut probe_iter = state.probes().iter();
let probe_next1 = probe_iter.next().unwrap();
assert_eq!(&probe_2_recv, probe_next1);
let probe_next2 = probe_iter.next().unwrap();
assert_eq!(&probe_3_recv, probe_next2);
}
}
/// Issues a single probe in each of two consecutive rounds, starting from sequence 65278.
///
/// For both rounds the test asserts that the state's sequence and round sequence begin at the
/// initial value, that the issued probe carries sequence 65278, and that the probe buffer holds
/// one awaited probe followed by `NotSent` entries.
#[test]
fn test_sequence_wrap1() {
    let initial_sequence = Sequence(65278);
    let mut state = TracerState::new(cfg(initial_sequence));
    for round in 0..2 {
        // The second pass runs after the round has been advanced once.
        if round == 1 {
            state.advance_round(TimeToLive(1));
        }
        assert_eq!(state.round, RoundId(round));
        assert_eq!(state.sequence, initial_sequence);
        assert_eq!(state.round_sequence, initial_sequence);

        let issued = state.next_probe(SystemTime::now());
        assert_eq!(issued.sequence, Sequence(65278));
        assert_eq!(state.sequence, Sequence(65279));

        let mut statuses = state.probes().iter();
        let first = statuses
            .next()
            .unwrap()
            .clone()
            .try_into_awaited()
            .unwrap();
        assert_eq!(first.sequence, Sequence(65278));
        // Every remaining inspected slot must still be unused.
        for status in statuses.take(BUFFER_SIZE as usize - 1) {
            assert!(matches!(status, ProbeStatus::NotSent));
        }
    }
}
/// Runs 2000 rounds of 254 probes each from an initial sequence of 33434.
///
/// Afterwards the round counter must be 2000 and both the round sequence and the current
/// sequence must equal 33434.
#[test]
fn test_sequence_wrap2() {
    let rounds = 2000;
    let probes_per_round = 254;
    let mut state = TracerState::new(cfg(Sequence(33434)));
    (0..rounds).for_each(|_| {
        (0..probes_per_round).for_each(|_| {
            let _ = state.next_probe(SystemTime::now());
        });
        state.advance_round(TimeToLive(1));
    });
    assert_eq!(state.round, RoundId(2000));
    assert_eq!(state.round_sequence, Sequence(33434));
    assert_eq!(state.sequence, Sequence(33434));
}
/// Runs 2000 rounds, each issuing a random number of probes drawn from `0..20`.
///
/// The test makes no assertions of its own; it passes as long as issuing probes and advancing
/// rounds never panics.
#[test]
fn test_sequence_wrap3() {
    let round_count = 2000;
    let probe_limit = 20;
    let mut state = TracerState::new(cfg(Sequence(33434)));
    let mut rng = rand::thread_rng();
    for _ in 0..round_count {
        // One random draw per round decides how many probes are issued.
        let probes_this_round = rng.gen_range(0..probe_limit);
        for _ in 0..probes_this_round {
            state.next_probe(SystemTime::now());
        }
        state.advance_round(TimeToLive(1));
    }
}
/// Runs 2000 rounds in which each of 254 iterations issues a probe and then reissues one.
///
/// Afterwards the round counter must be 2000 and both the round sequence and the current
/// sequence must equal 57310.
#[test]
fn test_sequence_wrap_with_skip() {
    let round_count = 2000;
    let pairs_per_round = 254;
    let mut state = TracerState::new(cfg(Sequence(33434)));
    for _round in 0..round_count {
        for _pair in 0..pairs_per_round {
            let _ = state.next_probe(SystemTime::now());
            let _ = state.reissue_probe(SystemTime::now());
        }
        state.advance_round(TimeToLive(1));
    }
    assert_eq!(state.round, RoundId(2000));
    assert_eq!(state.round_sequence, Sequence(57310));
    assert_eq!(state.sequence, Sequence(57310));
}
/// Checks `in_round` on a fresh state whose initial sequence is 33434.
///
/// Sequences 33434 and 33945 must be reported as in the round; 33946 must not.
#[test]
fn test_in_round() {
    let state = TracerState::new(cfg(Sequence(33434)));
    let cases = [(33434, true), (33945, true), (33946, false)];
    for (raw_sequence, expected) in cases {
        assert_eq!(state.in_round(Sequence(raw_sequence)), expected);
    }
}
/// Issues 55 probes from an initial sequence of 64000, advances the round, then asserts that
/// sequence 64491 is not in the new round.
///
/// The test is marked `should_panic`: it currently expects that final assertion to fail, i.e.
/// `in_round(Sequence(64491))` returns `true` at this point.
// NOTE(review): this looks like it records a known limitation rather than desired behaviour —
// confirm before changing the expectation.
#[test]
#[should_panic(expected = "assertion failed: !state.in_round(Sequence(64491))")]
fn test_in_delayed_probe_not_in_round() {
    let mut state = TracerState::new(cfg(Sequence(64000)));
    (0..55).for_each(|_| {
        _ = state.next_probe(SystemTime::now());
    });
    state.advance_round(TimeToLive(1));
    // The expression text below must stay as-is: the `should_panic` message matches on it.
    assert!(!state.in_round(Sequence(64491)));
}
/// Builds the `StrategyConfig` shared by the tests in this module.
///
/// Only `initial_sequence` varies between callers. The other fields are fixed: an unspecified
/// IPv4 target, ICMP protocol, TTL range 1..=24, no round limit, and default values for the
/// trace identifier, the in-flight limit and all durations.
fn cfg(initial_sequence: Sequence) -> StrategyConfig {
    let target_addr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
    let unset_duration = Duration::default();
    StrategyConfig {
        // Addressing and protocol.
        target_addr,
        protocol: Protocol::Icmp,
        multipath_strategy: MultipathStrategy::Classic,
        port_direction: PortDirection::None,
        trace_identifier: TraceId::default(),
        // Sequence and TTL bounds.
        initial_sequence,
        first_ttl: TimeToLive(1),
        max_ttl: TimeToLive(24),
        max_inflight: MaxInflight::default(),
        max_rounds: None,
        // Timing.
        grace_duration: unset_duration,
        min_round_duration: unset_duration,
        max_round_duration: unset_duration,
    }
}
}
}
/// Reports whether strictly more than `dur` has elapsed between `start` and `end`.
///
/// Returns `false` when `start` is `None`. If `end` is earlier than `start`, the elapsed time
/// is treated as zero, so the result is also `false`.
fn exceeds(start: Option<SystemTime>, end: SystemTime, dur: Duration) -> bool {
    let Some(begin) = start else {
        return false;
    };
    let elapsed = end.duration_since(begin).unwrap_or_default();
    elapsed > dur
}