use crate::net::{
self,
socket::{SocketHandle, SocketSet, UdpSocket, UdpSocketBuffer},
time::{Duration, Instant},
wire::{IpAddress, IpEndpoint},
Error,
};
use crate::wire::tftp::*;
use managed::ManagedSlice;
/// UDP port the server binds to when it is first served.
const TFTP_PORT: u16 = 69;
/// Delay after the last packet from a peer before its transfer is retransmitted.
const RETRY_TIMEOUT: Duration = Duration { millis: 200 };
/// Number of retransmissions attempted before a silent transfer is abandoned.
const MAX_RETRIES: u8 = 10;
/// Storage backend that the TFTP [`Server`] uses to open and release files.
pub trait Context {
    /// File handle type produced by [`Context::open`].
    type Handle: Handle;

    /// Opens `filename` for a new transfer.
    ///
    /// `write_mode` is `true` for a write request (the peer sends the file)
    /// and `false` for a read request (the server sends the file). Returning
    /// `Err(())` makes the server reply with a "file not found" error.
    fn open(&mut self, filename: &str, write_mode: bool) -> Result<Self::Handle, ()>;

    /// Releases a handle previously returned by [`Context::open`] when its
    /// transfer is finished or abandoned.
    fn close(&mut self, handle: Self::Handle);
}
/// An open file participating in a TFTP transfer.
pub trait Handle {
    /// Copies the next bytes of the file into `buf` and returns how many bytes
    /// were placed there.
    ///
    /// The server asks for 512-byte blocks; returning fewer than 512 bytes
    /// marks the final block of a read transfer.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, ()>;

    /// Stores the payload of the next in-order DATA block received from the peer.
    fn write(&mut self, buf: &[u8]) -> Result<usize, ()>;
}
/// A TFTP server owning a single UDP socket registered in a [`SocketSet`].
pub struct Server {
    /// Deadline reported to the caller by [`Server::next_poll`].
    next_poll: Instant,
    /// Handle of the server's UDP socket inside the `SocketSet`.
    udp_handle: SocketHandle,
}
impl Server {
pub fn new<'a, 'b, 'c>(
sockets: &mut SocketSet<'a, 'b, 'c>,
rx_buffer: UdpSocketBuffer<'b, 'c>,
tx_buffer: UdpSocketBuffer<'b, 'c>,
now: Instant,
) -> Self {
let socket = UdpSocket::new(rx_buffer, tx_buffer);
let udp_handle = sockets.add(socket);
net_trace!("TFTP initialised");
Server {
udp_handle,
next_poll: now,
}
}
pub fn next_poll(&self, now: Instant) -> Duration {
self.next_poll - now
}
pub fn serve<'a, C>(
&mut self,
sockets: &mut SocketSet,
context: &mut C,
transfers: &mut ManagedSlice<'a, Option<Transfer<C::Handle>>>,
now: Instant,
) -> net::Result<()>
where
C: Context,
{
let mut socket = sockets.get::<UdpSocket>(self.udp_handle);
if !socket.is_open() {
socket.bind(IpEndpoint {
addr: IpAddress::Unspecified,
port: TFTP_PORT,
})?;
}
self.next_poll = now + Duration::from_millis(50);
match socket.recv() {
Ok((data, ep)) => {
let tftp_packet = match Packet::new_checked(data) {
Ok(tftp_packet) => tftp_packet,
Err(_) => {
send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Packet truncated",
)?;
return Ok(());
}
};
let tftp_repr = match Repr::parse(&tftp_packet) {
Ok(tftp_repr) => tftp_repr,
Err(_) => {
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Malformed packet",
);
}
};
let xfer_idx = transfers.iter_mut().position(|xfer| {
if let Some(xfer) = xfer {
if xfer.ep == ep {
return true;
}
}
false
});
let is_write = tftp_packet.opcode() == OpCode::Write;
match (tftp_repr, xfer_idx) {
(Repr::ReadRequest { .. }, Some(_)) | (Repr::WriteRequest { .. }, Some(_)) => {
net_debug!("tftp: multiple connection attempts from {}", ep);
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Multiple connections not supported",
);
}
(Repr::ReadRequest { filename, mode, .. }, None)
| (Repr::WriteRequest { filename, mode, .. }, None) => {
if mode != Mode::Octet {
return send_error(
&mut *socket,
ep,
ErrorCode::IllegalOperation,
"Only octet mode is supported",
);
}
let opt_idx =
transfers.iter().position(|t| t.is_none()).or_else(
|| match transfers {
ManagedSlice::Borrowed(_) => None,
ManagedSlice::Owned(v) => {
let idx = v.len();
v.push(None);
Some(idx)
}
},
);
if let Some(idx) = opt_idx {
let handle = match context.open(filename, is_write) {
Ok(handle) => handle,
Err(_) => {
net_debug!("tftp: unable to open requested file");
return send_error(
&mut *socket,
ep,
ErrorCode::FileNotFound,
"Unable to open requested file",
);
}
};
let mut xfer = Transfer {
handle,
ep,
is_write,
block_num: 1,
last_data: None,
last_len: 0,
retries: 0,
timeout: now + Duration::from_millis(50),
};
net_debug!(
"tftp: {} request from {}",
if is_write { "write" } else { "read" },
ep
);
if is_write {
xfer.send_ack(&mut *socket, 0)?;
} else {
xfer.send_data(&mut *socket)?;
}
transfers[idx] = Some(xfer);
} else {
net_debug!("tftp: connections exhausted");
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"No more available connections",
);
}
}
(Repr::Data { .. }, None) | (Repr::Ack { .. }, None) => {
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Data packet without active transfer",
);
}
(Repr::Data { block_num, data }, Some(idx)) => {
let xfer = transfers[idx].as_mut().unwrap();
xfer.timeout = now + RETRY_TIMEOUT;
xfer.retries = 0;
if !xfer.is_write {
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Not a write connection",
);
}
if block_num != xfer.block_num {
return xfer.send_ack(&mut *socket, xfer.block_num - 1);
}
xfer.block_num += 1;
match xfer.handle.write(data) {
Ok(_) => {
let last_block = data.len() < 512;
xfer.send_ack(&mut *socket, block_num)?;
if last_block {
self.close_transfer(context, &mut transfers[idx]);
}
}
Err(_) => {
send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Error writing file",
)?;
self.close_transfer(context, &mut transfers[idx]);
}
}
}
(Repr::Ack { block_num }, Some(idx)) => {
let xfer = transfers[idx].as_mut().unwrap();
xfer.timeout = now + RETRY_TIMEOUT;
xfer.retries = 0;
if xfer.is_write {
return send_error(
&mut *socket,
ep,
ErrorCode::AccessViolation,
"Not a read connection",
);
}
if block_num != xfer.block_num {
return xfer.resend_data(&mut *socket);
}
xfer.block_num += 1;
if xfer.last_len == 512 {
xfer.send_data(&mut *socket)?;
} else {
self.close_transfer(context, &mut transfers[idx]);
}
}
(Repr::Error { .. }, _) => {
return send_error(
&mut *socket,
ep,
ErrorCode::IllegalOperation,
"Unknown operation",
);
}
}
Ok(())
}
Err(Error::Exhausted) => {
if socket.can_send() && now >= self.next_poll {
for xfer in transfers.iter_mut() {
let do_drop = if let Some(xfer) = xfer {
xfer.process_timeout(&mut socket, now)?
} else {
false
};
if do_drop {
self.close_transfer(context, xfer);
}
}
}
Ok(())
}
Err(e) => return Err(e),
}
}
fn close_transfer<C>(&mut self, context: &mut C, xfer: &mut Option<Transfer<C::Handle>>)
where
C: Context,
{
if let Some(xfer) = xfer.take() {
net_debug!("tftp: closing {}", xfer.ep);
context.close(xfer.handle);
}
}
}
/// State of one in-progress TFTP transfer with a single peer.
///
/// Values are created by [`Server::serve`] and stored in the caller-provided
/// `transfers` slice.
pub struct Transfer<H> {
    /// Peer endpoint; incoming packets are matched to transfers by this value.
    ep: IpEndpoint,
    /// File handle obtained from [`Context::open`].
    handle: H,
    /// `true` for a write request, `false` for a read request.
    is_write: bool,
    /// Next DATA block expected (writes), or the DATA block awaiting its ACK (reads).
    block_num: u16,
    /// Most recently read block, kept for retransmission; only filled for reads.
    last_data: Option<[u8; 512]>,
    /// Number of valid bytes in `last_data`.
    last_len: usize,
    /// Deadline after which the transfer is considered timed out.
    timeout: Instant,
    /// Timeouts counted since a DATA or ACK was last received from the peer.
    retries: u8,
}
impl<H> Transfer<H>
where
H: Handle,
{
fn process_timeout(&mut self, socket: &mut UdpSocket, now: Instant) -> net::Result<bool> {
if now >= self.timeout && self.retries < MAX_RETRIES {
self.retries += 1;
self.resend_data(socket).map(|_| false)
} else {
net_debug!("tftp: connection timeout");
Ok(true)
}
}
fn send_data(&mut self, socket: &mut UdpSocket) -> net::Result<bool> {
if self.last_data.is_none() {
self.last_data = Some([0; 512]);
}
self.last_len = match self.handle.read(&mut self.last_data.as_mut().unwrap()[..]) {
Ok(n) => n,
Err(_) => {
send_error(
socket,
self.ep,
ErrorCode::AccessViolation,
"Error occurred while reading the file",
)?;
return Ok(false);
}
};
self.resend_data(socket).map(|_| false)
}
fn resend_data(&mut self, socket: &mut UdpSocket) -> net::Result<()> {
if let Some(last_data) = &self.last_data {
net_trace!("tftp: sending data block #{}", self.block_num);
let data = Repr::Data {
block_num: self.block_num,
data: &last_data[..self.last_len],
};
let payload = socket.send(data.buffer_len(), self.ep)?;
let mut pkt = Packet::new_unchecked(payload);
data.emit(&mut pkt)?;
}
Ok(())
}
fn send_ack(&mut self, socket: &mut UdpSocket, block: u16) -> net::Result<()> {
net_trace!("tftp: sending ack #{}", block);
let ack = Repr::Ack { block_num: block };
let payload = socket.send(ack.buffer_len(), self.ep)?;
let mut pkt = Packet::new_unchecked(payload);
ack.emit(&mut pkt)
}
}
fn send_error(
socket: &mut UdpSocket,
ep: IpEndpoint,
code: ErrorCode,
msg: &str,
) -> net::Result<()> {
net_debug!("tftp: {:?}, message: {}", code, msg);
let err = Repr::Error { code, msg };
let payload = socket.send(err.buffer_len(), ep)?;
let mut pkt = Packet::new_unchecked(payload);
err.emit(&mut pkt)
}