use std::{
collections::VecDeque,
io::{self, Write},
};
use bytes::{Bytes, BytesMut};
use futures::{
sync::mpsc::{Receiver, Sender},
task, Async, Poll, Stream,
};
use log::debug;
use tokio::prelude::{AsyncRead, AsyncWrite};
use crate::{
error::Error,
frame::{Flag, Flags, Frame, Type},
StreamId,
};
/// One multiplexed stream: buffers received data, tracks both flow-control windows, and
/// exchanges frames and events over two channels (presumably with the owning session).
#[derive(Debug)]
pub struct StreamHandle {
    /// Identifier stamped on every frame this handle emits.
    id: StreamId,
    /// Current protocol state; see `StreamState`.
    state: StreamState,
    /// Receive window size given at construction; `send_window_update` tops the window
    /// back up towards this value.
    max_recv_window: u32,
    /// Bytes the peer may still send; `handle_data` rejects frames larger than this.
    recv_window: u32,
    /// Bytes this side may still send; grown by window-update frames, shrunk by writes.
    send_window: u32,
    /// Received payload not yet handed out by `io::Read::read`.
    read_buf: BytesMut,
    /// Bytes accepted by `io::Write::write` beyond `send_window`, sent once the peer
    /// grants more window.
    write_buf: BytesMut,
    /// Window updates (flags, delta) that could not be queued because the event channel
    /// was full; `send_event` drains these before sending anything else.
    window_update_frame_buf: VecDeque<(Flags, u32)>,
    /// Outgoing side: frames to send plus flush/state-change notifications.
    event_sender: Sender<StreamEvent>,
    /// Incoming side: frames addressed to this stream.
    frame_receiver: Receiver<Frame>,
}
impl StreamHandle {
pub(crate) fn new(
id: StreamId,
event_sender: Sender<StreamEvent>,
frame_receiver: Receiver<Frame>,
state: StreamState,
recv_window_size: u32,
send_window_size: u32,
) -> StreamHandle {
assert!(state == StreamState::Init || state == StreamState::SynReceived);
StreamHandle {
id,
state,
max_recv_window: recv_window_size,
recv_window: recv_window_size,
send_window: send_window_size,
read_buf: BytesMut::default(),
write_buf: BytesMut::default(),
window_update_frame_buf: VecDeque::default(),
event_sender,
frame_receiver,
}
}
pub fn id(&self) -> StreamId {
self.id
}
pub fn state(&self) -> StreamState {
self.state
}
pub fn recv_window(&self) -> u32 {
self.recv_window
}
pub fn send_window(&self) -> u32 {
self.send_window
}
fn close(&mut self) -> Result<(), Error> {
match self.state {
StreamState::SynSent | StreamState::SynReceived | StreamState::Established => {
self.state = StreamState::LocalClosing;
self.send_close()?;
}
StreamState::RemoteClosing => {
self.state = StreamState::Closed;
self.send_close()?;
let event = StreamEvent::StateChanged((self.id, self.state));
self.send_event(event)?;
}
StreamState::Reset | StreamState::Closed => {
self.state = StreamState::Closed;
let event = StreamEvent::StateChanged((self.id, self.state));
self.send_event(event)?;
}
_ => {}
}
Ok(())
}
#[inline]
fn send_event(&mut self, event: StreamEvent) -> Result<(), Error> {
debug!("[{}] StreamHandle.send_event({:?})", self.id, event);
while let Some((flag, delta)) = self.window_update_frame_buf.pop_front() {
let event = StreamEvent::Frame(Frame::new_window_update(flag, self.id, delta));
if let Err(e) = self.event_sender.try_send(event) {
if e.is_full() {
self.window_update_frame_buf.push_front((flag, delta));
return Err(Error::WouldBlock);
} else {
return Err(Error::SessionShutdown);
}
}
}
if let Err(e) = self.event_sender.try_send(event) {
if e.is_full() {
return Err(Error::WouldBlock);
} else {
return Err(Error::SessionShutdown);
}
}
Ok(())
}
#[inline]
fn send_frame(&mut self, frame: Frame) -> Result<(), Error> {
let event = StreamEvent::Frame(frame);
self.send_event(event)
}
pub(crate) fn send_window_update(&mut self) -> Result<(), Error> {
let buf_len = self.read_buf.len() as u32;
let delta = self.max_recv_window - buf_len - self.recv_window;
let flags = self.get_flags();
if delta < (self.max_recv_window / 2) && flags.value() == 0 {
return Ok(());
}
self.recv_window += delta;
let frame = Frame::new_window_update(flags, self.id, delta);
match self.send_frame(frame) {
Err(ref e) if e == &Error::WouldBlock => {
self.window_update_frame_buf.push_back((flags, delta))
}
Err(e) => return Err(e),
_ => (),
}
Ok(())
}
fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
let flags = self.get_flags();
let frame = Frame::new_data(flags, self.id, Bytes::from(data));
self.send_frame(frame)
}
fn send_close(&mut self) -> Result<(), Error> {
let mut flags = self.get_flags();
flags.add(Flag::Fin);
let frame = Frame::new_window_update(flags, self.id, 0);
self.send_frame(frame)
}
fn process_flags(&mut self, flags: Flags) -> Result<(), Error> {
if flags.contains(Flag::Ack) && self.state == StreamState::SynSent {
self.state = StreamState::SynReceived;
}
let mut close_stream = false;
if flags.contains(Flag::Fin) {
match self.state {
StreamState::Init
| StreamState::SynSent
| StreamState::SynReceived
| StreamState::Established => {
self.state = StreamState::RemoteClosing;
}
StreamState::LocalClosing => {
self.state = StreamState::Closed;
close_stream = true;
}
_ => return Err(Error::UnexpectedFlag),
}
}
if flags.contains(Flag::Rst) {
self.state = StreamState::Reset;
close_stream = true;
}
if close_stream {
self.close()?;
}
Ok(())
}
fn get_flags(&mut self) -> Flags {
match self.state {
StreamState::Init => {
self.state = StreamState::SynSent;
Flags::from(Flag::Syn)
}
StreamState::SynReceived => {
self.state = StreamState::Established;
Flags::from(Flag::Ack)
}
_ => Flags::default(),
}
}
fn handle_frame(&mut self, frame: Frame) -> Result<(), Error> {
debug!("[{}] StreamHandle.handle_frame({:?})", self.id, frame);
match frame.ty() {
Type::WindowUpdate => {
self.handle_window_update(&frame)?;
}
Type::Data => {
self.handle_data(frame)?;
}
_ => {
return Err(Error::InvalidMsgType);
}
}
Ok(())
}
fn handle_window_update(&mut self, frame: &Frame) -> Result<(), Error> {
self.process_flags(frame.flags())?;
self.send_window += frame.length();
let n = ::std::cmp::min(self.send_window as usize, self.write_buf.len());
if n != 0 {
let b = self.write_buf.split_to(n);
let _ = self.write(&b);
} else {
task::current().notify();
}
Ok(())
}
fn handle_data(&mut self, frame: Frame) -> Result<(), Error> {
self.process_flags(frame.flags())?;
let length = frame.length();
if length > self.recv_window {
return Err(Error::RecvWindowExceeded);
}
let (_, body) = frame.into_parts();
if let Some(data) = body {
self.read_buf.extend_from_slice(&data);
}
self.recv_window -= length;
Ok(())
}
fn recv_frames(&mut self) -> Poll<(), Error> {
loop {
match self.state {
StreamState::RemoteClosing => {
return Err(Error::SubStreamRemoteClosing);
}
StreamState::Reset | StreamState::Closed => {
return Err(Error::SessionShutdown);
}
_ => {}
}
match self
.frame_receiver
.poll()
.map_err(|()| Error::SessionShutdown)?
{
Async::Ready(Some(frame)) => self.handle_frame(frame)?,
Async::Ready(None) => {
return Err(Error::SessionShutdown);
}
Async::NotReady => {
return Ok(Async::NotReady);
}
}
}
}
fn check_self_state(&mut self) -> Result<(), io::Error> {
if self.read_buf.is_empty() {
match self.state {
StreamState::RemoteClosing | StreamState::Closed => {
debug!("closed(EOF)");
self.shutdown()?;
Err(io::ErrorKind::UnexpectedEof.into())
}
StreamState::Reset => {
debug!("connection reset");
self.shutdown()?;
Err(io::ErrorKind::ConnectionReset.into())
}
_ => Ok(()),
}
} else {
Ok(())
}
}
}
impl io::Read for StreamHandle {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.check_self_state()?;
let rv = self.recv_frames();
debug!(
"[{}] StreamHandle.read() recv_frames() => {:?}, state: {:?}",
self.id, rv, self.state
);
self.check_self_state()?;
let n = ::std::cmp::min(buf.len(), self.read_buf.len());
if n == 0 {
return Err(io::ErrorKind::WouldBlock.into());
}
let b = self.read_buf.split_to(n);
debug!(
"[{}] StreamHandle.read({}), buf.len()={}, read_buf.len()={}",
self.id,
n,
buf.len(),
self.read_buf.len()
);
buf[..n].copy_from_slice(&b);
match self.state {
StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => (),
_ => {
if self.send_window_update().is_err() {
return Err(io::ErrorKind::BrokenPipe.into());
}
}
}
Ok(n)
}
}
impl io::Write for StreamHandle {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
debug!("[{}] StreamHandle.write({:?})", self.id, buf);
if let Err(e) = self.recv_frames() {
match e {
Error::SessionShutdown => return Err(io::ErrorKind::BrokenPipe.into()),
Error::UnexpectedFlag | Error::RecvWindowExceeded => {
return Err(io::ErrorKind::InvalidData.into());
}
Error::SubStreamRemoteClosing => (),
Error::WouldBlock => return Err(io::ErrorKind::WouldBlock.into()),
_ => unimplemented!(),
}
}
if self.state == StreamState::LocalClosing || self.state == StreamState::Closed {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"The local is closed and data cannot be written.",
));
}
if self.send_window == 0 {
return Err(io::ErrorKind::WouldBlock.into());
}
let n = ::std::cmp::min(self.send_window as usize, buf.len());
let data = &buf[0..n];
match self.send_data(data) {
Ok(_) => {
self.send_window -= n as u32;
self.write_buf.extend_from_slice(&buf[n..]);
Ok(buf.len())
}
Err(ref e) if e == &Error::WouldBlock => Err(io::ErrorKind::WouldBlock.into()),
_ => Err(io::ErrorKind::BrokenPipe.into()),
}
}
fn flush(&mut self) -> io::Result<()> {
debug!("[{}] StreamHandle.flush()", self.id);
if let Err(e) = self.recv_frames() {
match e {
Error::SessionShutdown => return Err(io::ErrorKind::BrokenPipe.into()),
Error::UnexpectedFlag | Error::RecvWindowExceeded => {
return Err(io::ErrorKind::InvalidData.into());
}
Error::SubStreamRemoteClosing => (),
Error::WouldBlock => return Err(io::ErrorKind::WouldBlock.into()),
_ => unimplemented!(),
}
}
let event = StreamEvent::Flush(self.id);
match self.send_event(event) {
Err(ref e) if e == &Error::WouldBlock => Err(io::ErrorKind::WouldBlock.into()),
Err(_) => Err(io::ErrorKind::BrokenPipe.into()),
Ok(()) => Ok(()),
}
}
}
/// Marker impl: all reading goes through the `io::Read` impl for `StreamHandle`, which
/// reports "no data yet" as `io::ErrorKind::WouldBlock`.
impl AsyncRead for StreamHandle {}
impl AsyncWrite for StreamHandle {
fn shutdown(&mut self) -> Poll<(), io::Error> {
debug!("[{}] StreamHandle.shutdown()", self.id);
match self.close() {
Err(ref e) if e == &Error::WouldBlock => Err(io::ErrorKind::WouldBlock.into()),
Err(_) => Err(io::ErrorKind::BrokenPipe.into()),
Ok(()) => Ok(Async::Ready(())),
}
}
}
/// Messages a `StreamHandle` pushes onto its event channel (presumably consumed by the
/// owning session).
#[derive(Debug)]
pub(crate) enum StreamEvent {
    /// An outgoing frame for this stream (data, window update, FIN).
    Frame(Frame),
    /// The stream with this id changed to this state; emitted by `close` on reaching
    /// `StreamState::Closed`.
    StateChanged((StreamId, StreamState)),
    /// Emitted by `io::Write::flush` for the stream with this id.
    Flush(StreamId),
}
/// Lifecycle of a stream as tracked by `StreamHandle`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamState {
    /// Nothing sent yet; the next outgoing frame carries SYN and moves to `SynSent`.
    Init,
    /// SYN has been sent; an incoming ACK moves the stream to `SynReceived`.
    SynSent,
    /// Initial state for handles constructed this way, or reached from `SynSent` on ACK.
    /// The next outgoing frame carries ACK and moves to `Established`.
    SynReceived,
    /// Handshake flags have been sent; outgoing frames carry no SYN/ACK.
    Established,
    /// This side has sent FIN; a FIN from the peer moves the stream to `Closed`.
    LocalClosing,
    /// The peer has sent FIN; closing locally moves the stream to `Closed`.
    RemoteClosing,
    /// Both directions are finished.
    Closed,
    /// The peer sent RST; `close` then moves the stream to `Closed`.
    Reset,
}