1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use {WireDecoder, WireEncoder, WireMessage, std, wire};
use std::net::UdpSocket;
use std::sync::Arc;
/// Error type returned by `Server::new` and `Server::serve`.
#[derive(Debug)]
pub enum ServerError {
    /// An I/O operation on the UDP socket failed.
    Io {
        /// The underlying I/O error as reported by the standard library.
        inner: std::io::Error,
        /// Human-readable description of the operation that failed.
        what: String,
    },
}
impl std::error::Error for ServerError {
    /// Returns the context message stored in the error's `what` field.
    ///
    /// The wrapped I/O error is not part of this string; it is appended by
    /// the `Display` implementation instead.
    fn description(&self) -> &str {
        match *self {
            ServerError::Io { ref what, .. } => what.as_str(),
        }
    }
}
impl std::fmt::Display for ServerError {
    /// Formats the error as `<what>: <inner>`, i.e. the context message
    /// followed by the underlying I/O error's own display text.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            ServerError::Io { ref what, ref inner } => write!(f, "{}: {}", what, inner),
        }
    }
}
/// Callback interface the server uses to answer each decoded query.
pub trait Handler {
    /// Error type associated with the handler.
    ///
    /// NOTE(review): no method visible in this trait returns or accepts this
    /// type; confirm whether it is used elsewhere before relying on it.
    type Error: std::error::Error;

    /// Produces the response for `query`.
    ///
    /// `encoder` arrives typed for the answer section of a response and must
    /// be returned in its `Done` state; the server then sends the bytes the
    /// returned encoder exposes. The lifetime `'a` ties the returned encoder
    /// to the same output buffer the incoming encoder borrows.
    fn handle_query<'a>(
        &self,
        query: &WireMessage,
        encoder: WireEncoder<'a, wire::marker::Response, wire::marker::AnswerSection>,
    ) -> WireEncoder<'a, wire::marker::Response, wire::marker::Done>;
}
/// UDP server that hands every decoded query to a borrowed `Handler`.
///
/// The handler is borrowed for `'a`, so it must outlive the server.
#[derive(Debug)]
pub struct Server<'a, H>
where
    H: Handler + 'a,
{
    /// Socket used both to receive queries and to send responses.
    socket: Arc<UdpSocket>,
    /// Handler consulted for each successfully decoded query.
    handler: &'a H,
}
impl<'a, H: Handler> Server<'a, H> {
    /// Creates a server bound to UDP port 53 on all IPv4 interfaces
    /// (`0.0.0.0:53`).
    ///
    /// `h` is the handler consulted for every decoded query.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::Io` if the socket cannot be bound.
    pub fn new(h: &'a H) -> Result<Self, ServerError> {
        Self::with_address(h, "0.0.0.0:53")
    }

    /// Creates a server bound to `addr` (for example `"127.0.0.1:5353"`).
    ///
    /// This is the general form of `new`; it allows binding to a specific
    /// interface or to an unprivileged port.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::Io` if the socket cannot be bound to `addr`.
    pub fn with_address(h: &'a H, addr: &str) -> Result<Self, ServerError> {
        let socket = UdpSocket::bind(addr).map_err(|e| {
            ServerError::Io {
                inner: e,
                what: format!("Failed to bind UDP socket to {}", addr),
            }
        })?;
        Ok(Server {
            socket: Arc::new(socket),
            handler: h,
        })
    }

    /// Receives queries and sends responses until receiving fails.
    ///
    /// For each datagram: decode it, let the handler fill in a response, and
    /// send the response back to the sender. Datagrams that fail to decode,
    /// or for which no response encoder can be created, are skipped. Send
    /// failures are logged to stdout and do not stop the server.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::Io` when receiving from the socket fails for any
    /// reason other than `ErrorKind::Interrupted`, which is retried. The loop
    /// has no other exit, so `Ok(())` is never actually returned.
    pub fn serve(self) -> Result<(), ServerError> {
        // Both buffers are capped at this size; presumably the classic
        // DNS-over-UDP message limit (TODO confirm against the wire module).
        const MAX_UDP_MESSAGE_LEN: usize = 512;
        let mut ibuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
        loop {
            let (recv_len, peer_addr) = match self.socket.recv_from(&mut ibuffer) {
                Ok(received) => received,
                // An interrupted blocking call is a retryable condition, not a
                // socket failure, so it must not take the whole server down.
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => {
                    return Err(ServerError::Io {
                        inner: e,
                        what: String::from("Failed to receive from UDP socket"),
                    });
                }
            };

            let ipayload = &ibuffer[..recv_len];
            let mut decoder = WireDecoder::new(ipayload);
            let request = match decoder.decode_message() {
                Ok(x) => x,
                Err(e) => {
                    println!("Received invalid message: {}", e);
                    continue;
                }
            };

            // A fresh, zeroed output buffer is used for every response.
            let mut obuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
            let encoder = match WireEncoder::new_response(&mut obuffer[..], &request) {
                Ok(x) => x,
                // No response can be started for this request; the error is
                // discarded and the datagram is dropped without a reply.
                Err(_) => continue,
            };
            let encoder = self.handler.handle_query(&request, encoder);
            let opayload = encoder.as_bytes();

            match self.socket.send_to(opayload, peer_addr) {
                Ok(send_len) => {
                    if send_len != opayload.len() {
                        println!("Sent unexpected number of bytes on UDP socket: Expected to send {}, actually sent \
                                  {}",
                                 opayload.len(),
                                 send_len);
                    }
                }
                Err(e) => {
                    println!("Failed to send on UDP socket: {}", e);
                }
            }
        }
    }
}