use std::borrow::Borrow;
use std::fmt::{self, Display};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::{Future, Stream};
use log::{debug, warn};
use tokio::time::Elapsed;
use crate::error::ProtoError;
use crate::op::message::NoopMessageFinalizer;
use crate::op::{MessageFinalizer, OpCode};
use crate::udp::udp_stream::{NextRandomUdpSocket, UdpSocket};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, SerialMessage};
/// A DNS client that talks UDP to a single name server.
///
/// The stream itself holds no socket: every `send_message` call builds an independent
/// `UdpResponse` for that one request. `S` selects the UDP socket implementation and `MF`
/// the finalizer (signer) applied to update messages when one is configured.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<S: Send, MF: MessageFinalizer = NoopMessageFinalizer> {
    // Destination of every request sent through this stream.
    name_server: SocketAddr,
    // Per-request deadline handed to each `UdpResponse`.
    timeout: Duration,
    // Set by `shutdown`; sending afterwards panics.
    is_shutdown: bool,
    // Optional finalizer used to sign update messages.
    signer: Option<Arc<MF>>,
    marker: PhantomData<S>,
}
impl<S: Send> UdpClientStream<S, NoopMessageFinalizer> {
#[allow(clippy::new_ret_no_self)]
pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
Self::with_timeout(name_server, Duration::from_secs(5))
}
pub fn with_timeout(
name_server: SocketAddr,
timeout: Duration,
) -> UdpClientConnect<S, NoopMessageFinalizer> {
Self::with_timeout_and_signer(name_server, timeout, None)
}
}
impl<S: Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
    /// Returns a future that resolves to a client stream for `name_server`.
    ///
    /// `timeout` bounds each individual request. When `signer` is `Some`, it is used to
    /// finalize update messages before they are sent.
    pub fn with_timeout_and_signer(
        name_server: SocketAddr,
        timeout: Duration,
        signer: Option<Arc<MF>>,
    ) -> UdpClientConnect<S, MF> {
        // The address is wrapped in `Option` so the connect future can move it out exactly once.
        let pending_name_server = Some(name_server);
        UdpClientConnect {
            name_server: pending_name_server,
            timeout,
            signer,
            marker: PhantomData,
        }
    }
}
impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
    /// Renders the stream as `UDP(<name server address>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let target = &self.name_server;
        write!(formatter, "UDP({})", target)
    }
}
/// Draws a fresh DNS message id from the thread-local RNG.
fn random_query_id() -> u16 {
    use rand::distributions::{Distribution, Standard};

    Distribution::<u16>::sample(&Standard, &mut rand::thread_rng())
}
impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
    for UdpClientStream<S, MF>
{
    type DnsResponseFuture = UdpResponse;

    /// Assigns a random id to `message`, signs it when required, and starts a one-shot
    /// UDP exchange with the configured name server.
    ///
    /// Only `OpCode::Update` messages are finalized, and only when a signer is configured.
    /// The signing time is seconds since the Unix epoch, cast to `u32`. The cast wraps in
    /// 2106; presumably this matches the 32-bit time fields the finalizer expects, but that
    /// should be confirmed. Clock, signing and serialization failures are returned as an
    /// already-failed `UdpResponse`.
    ///
    /// # Panics
    ///
    /// Panics if `shutdown` has already been called on this stream.
    fn send_message(
        &mut self,
        mut message: DnsRequest,
        _cx: &mut Context,
    ) -> Self::DnsResponseFuture {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        message.set_id(random_query_id());

        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|_| ProtoError::from("Current time is before the Unix epoch."));
        let now = match since_epoch {
            Ok(elapsed) => elapsed.as_secs() as u32,
            Err(err) => return Self::error_response(err),
        };

        // Pick the signer only for update messages; everything else goes out unsigned.
        let signer = match message.op_code() {
            OpCode::Update => self.signer.as_ref(),
            _ => None,
        };
        if let Some(signer) = signer {
            if let Err(e) = message.finalize::<MF>(signer.borrow(), now) {
                debug!("could not sign message: {}", e);
                return Self::error_response(e);
            }
        }

        let bytes = match message.to_vec() {
            Ok(bytes) => bytes,
            Err(err) => return Self::error_response(err),
        };

        let message_id = message.id();
        let serial = SerialMessage::new(bytes, self.name_server);
        UdpResponse::new::<S>(serial, message_id, self.timeout)
    }

    /// Wraps `err` in a `UdpResponse` that resolves to that error when polled.
    fn error_response(err: ProtoError) -> Self::DnsResponseFuture {
        UdpResponse::complete(SingleUseUdpSocket::errored(err))
    }

    /// Marks the stream as shut down; later `send_message` calls panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Reports whether `shutdown` has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
    type Item = Result<(), ProtoError>;

    /// Never pends. It yields `Ok(())` while the stream is open and ends (`None`) once it
    /// has been shut down.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
        let item = if self.is_shutdown { None } else { Some(Ok(())) };
        Poll::Ready(item)
    }
}
/// Boxed, deadline-bounded future for one UDP exchange.
///
/// The outer `Result` reports whether the timeout elapsed; the inner one carries the
/// outcome of the exchange itself.
type TimedExchange =
    Pin<Box<dyn Future<Output = Result<Result<DnsResponse, ProtoError>, Elapsed>> + Send>>;

/// Future returned by `UdpClientStream::send_message`, resolving to the name server's reply.
pub struct UdpResponse(TimedExchange);
impl UdpResponse {
    /// Starts a single-use socket exchange for `request` and bounds it with `timeout`.
    fn new<S: UdpSocket + Send + Unpin + 'static>(
        request: SerialMessage,
        message_id: u16,
        timeout: Duration,
    ) -> Self {
        let exchange = SingleUseUdpSocket::send_serial_message::<S>(request, message_id);
        let bounded = tokio::time::timeout(timeout, exchange);
        UdpResponse(Box::pin(bounded))
    }

    /// Wraps an arbitrary result-producing future under a fixed 5 second timeout.
    ///
    /// Within this file it is used for already-resolved error futures.
    fn complete<F>(f: F) -> Self
    where
        F: Future<Output = Result<DnsResponse, ProtoError>> + Send + 'static,
    {
        let bounded = tokio::time::timeout(Duration::from_secs(5), f);
        UdpResponse(Box::pin(bounded))
    }
}
impl Future for UdpResponse {
    type Output = Result<DnsResponse, ProtoError>;

    /// Polls the inner exchange. An elapsed deadline becomes a `ProtoError`; otherwise the
    /// exchange's own result is passed through unchanged.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.as_mut().poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(elapsed)) => Poll::Ready(Err(ProtoError::from(elapsed))),
            Poll::Ready(Ok(outcome)) => Poll::Ready(outcome),
        }
    }
}
/// Future that yields a ready `UdpClientStream` on its first poll.
///
/// It is produced by the `UdpClientStream` constructors and holds their configuration
/// until then.
pub struct UdpClientConnect<S: Send, MF: MessageFinalizer = NoopMessageFinalizer> {
    // Taken on the first poll; polling a second time panics.
    name_server: Option<SocketAddr>,
    timeout: Duration,
    signer: Option<Arc<MF>>,
    marker: PhantomData<S>,
}
impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
    type Output = Result<UdpClientStream<S, MF>, ProtoError>;

    /// Completes immediately by moving the stored configuration into a new, open stream.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already completed.
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
        let name_server = self
            .name_server
            .take()
            .expect("UdpClientConnect invalid state: name_server");
        let signer = self.signer.take();

        let stream = UdpClientStream::<S, MF> {
            name_server,
            timeout: self.timeout,
            is_shutdown: false,
            signer,
            marker: PhantomData,
        };
        Poll::Ready(Ok(stream))
    }
}
/// Namespace for the one-shot UDP exchange.
///
/// Each request obtains its own socket from `NextRandomUdpSocket`, sends one message, waits
/// for a matching reply, and then drops the socket.
struct SingleUseUdpSocket;

impl SingleUseUdpSocket {
    /// Sends `msg` to `msg.addr()` over a freshly created socket and waits for its response.
    ///
    /// A received datagram is logged and ignored if:
    /// - it comes from any address other than the name server,
    /// - it carries a message id other than `msg_id`, or
    /// - it fails to parse.
    ///
    /// The loop ends only on a matching reply or a socket error, so callers bound it with a
    /// timeout (`UdpResponse::new`).
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created, if sending or receiving fails, or if
    /// only part of the message was sent.
    async fn send_serial_message<S: UdpSocket + Send>(
        msg: SerialMessage,
        msg_id: u16,
    ) -> Result<DnsResponse, ProtoError> {
        // EDNS(0) (RFC 6891) lets responses exceed 512 bytes. Receiving a datagram into a
        // smaller buffer truncates it (or errors on some platforms). The truncated reply would
        // fail to parse and be dropped below, leaving the request to time out.
        const MAX_RECEIVE_BUFFER_SIZE: usize = 4_096;

        let name_server = msg.addr();
        let mut socket: S = NextRandomUdpSocket::new(&name_server).await?;

        let bytes = msg.bytes();
        let len_sent: usize = socket.send_to(bytes, &name_server).await?;
        if bytes.len() != len_sent {
            return Err(ProtoError::from(format!(
                "Not all bytes of message sent, {} of {}",
                len_sent,
                bytes.len()
            )));
        }

        // Reused across iterations: each datagram is copied out of the first `len` bytes.
        let mut recv_buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
        loop {
            let (len, src) = socket.recv_from(&mut recv_buf).await?;
            let response = SerialMessage::new(recv_buf[..len].to_vec(), src);

            // Only the server we queried may answer; anything else could be spoofed traffic.
            if response.addr() != name_server {
                warn!(
                    "ignoring response from {} because it does not match name_server: {}.",
                    response.addr(),
                    name_server,
                );
                continue;
            }

            match response.to_message() {
                Ok(message) => {
                    if msg_id == message.id() {
                        debug!("received message id: {}", message.id());
                        return Ok(DnsResponse::from(message));
                    } else {
                        warn!(
                            "expected message id: {} got: {}, dropped",
                            msg_id,
                            message.id()
                        );
                        continue;
                    }
                }
                Err(e) => {
                    warn!(
                        "dropped malformed message waiting for id: {} err: {}",
                        msg_id, e
                    );
                    continue;
                }
            }
        }
    }

    /// Returns an already-failed result, so error paths can share the `UdpResponse` wrapper.
    async fn errored(err: ProtoError) -> Result<DnsResponse, ProtoError> {
        Err(err)
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use futures::future;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;
    use crate::op::Message;

    #[test]
    fn test_udp_client_stream_ipv4() {
        udp_client_stream_test(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
    }

    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_client_stream_ipv6() {
        udp_client_stream_test(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
    }

    /// Runs several request/response round trips against a blocking `std` UDP server bound
    /// to `server_addr` on an OS-chosen port. It checks that at least one response carried
    /// the expected NULL payload.
    fn udp_client_stream_test(server_addr: IpAddr) {
        use crate::op::Query;
        use crate::rr::rdata::NULL;
        use crate::rr::{Name, RData, Record, RecordType};
        use std::str::FromStr;
        use tokio::runtime;

        // Watchdog: if success has not been signalled within ~15s, terminate the process.
        // A `panic!` here would only unwind this detached helper thread (its JoinHandle is
        // dropped) and would neither fail nor stop a hung test, so exit the process instead.
        let succeeded = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let succeeded_clone = succeeded.clone();
        std::thread::Builder::new()
            .name("thread_killer".to_string())
            .spawn(move || {
                let succeeded = succeeded_clone;
                for _ in 0..15 {
                    std::thread::sleep(std::time::Duration::from_secs(1));
                    if succeeded.load(std::sync::atomic::Ordering::Relaxed) {
                        return;
                    }
                }
                println!("Thread Killer has been awoken, killing process");
                std::process::exit(-1);
            })
            .unwrap();

        let server = std::net::UdpSocket::bind(SocketAddr::new(server_addr, 0)).unwrap();
        server
            .set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();
        server
            .set_write_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();
        let server_addr = server.local_addr().unwrap();

        let mut query = Message::new();
        let test_name = Name::from_str("dead.beef").unwrap();
        query.add_query(Query::query(test_name.clone(), RecordType::NULL));
        let test_bytes: &'static [u8; 8] = b"DEADBEEF";
        let send_recv_times = 4;

        // Server: echo each query back with a single NULL answer containing `test_bytes`.
        let test_name_server = test_name;
        let server_handle = std::thread::Builder::new()
            .name("test_udp_client_stream_ipv4:server".to_string())
            .spawn(move || {
                let mut buffer = [0_u8; 512];
                for i in 0..send_recv_times {
                    debug!("server receiving request {}", i);
                    let (len, addr) = server.recv_from(&mut buffer).expect("receive failed");
                    debug!("server received request {} from: {}", i, addr);
                    let request =
                        Message::from_vec(&buffer[0..len]).expect("failed parse of request");
                    assert_eq!(*request.queries()[0].name(), test_name_server.clone());
                    assert_eq!(request.queries()[0].query_type(), RecordType::NULL);
                    let mut message = Message::new();
                    message.set_id(request.id());
                    message.add_queries(request.queries().to_vec());
                    message.add_answer(Record::from_rdata(
                        test_name_server.clone(),
                        0,
                        RData::NULL(NULL::with(test_bytes.to_vec())),
                    ));
                    let bytes = message.to_vec().unwrap();
                    debug!("server sending response {} to: {}", i, addr);
                    assert_eq!(
                        server.send_to(&bytes, addr).expect("send failed"),
                        bytes.len()
                    );
                    debug!("server sent response {}", i);
                    std::thread::yield_now();
                }
            })
            .unwrap();

        // Client: individual failures are tolerated, but at least one round trip must succeed.
        let mut io_loop = runtime::Runtime::new().unwrap();
        let stream = UdpClientStream::with_timeout(server_addr, Duration::from_millis(500));
        let mut stream: UdpClientStream<tokio::net::UdpSocket> =
            io_loop.block_on(stream).ok().unwrap();
        let mut worked_once = false;
        for i in 0..send_recv_times {
            let response_future = io_loop.block_on(future::lazy(|cx| {
                stream.send_message(DnsRequest::new(query.clone(), Default::default()), cx)
            }));
            println!("client sending request {}", i);
            let response = match io_loop.block_on(response_future) {
                Ok(response) => response,
                Err(err) => {
                    println!("failed to get message: {}", err);
                    continue;
                }
            };
            println!("client got response {}", i);
            let response = Message::from(response);
            if let RData::NULL(null) = response.answers()[0].rdata() {
                assert_eq!(null.anything().expect("no bytes in NULL"), test_bytes);
            } else {
                panic!("not a NULL response");
            }
            worked_once = true;
        }
        succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
        server_handle.join().expect("server thread failed");
        assert!(worked_once);
    }
}