1use core::pin::Pin;
2use core::task::{Context, Poll, Poll::Pending, Poll::Ready};
3use futures::io::{AsyncWrite, AsyncWriteExt};
4use futures::sink::Sink;
5use std::mem::replace;
6
7use crate::packet::*;
9use crate::PinFut;
10use snafu::{ResultExt, Snafu};
11
/// Errors reported by the packet sink.
///
/// Every variant wraps the `std::io::Error` returned by the underlying writer.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Writing the goodbye bytes to the writer failed (produced by `send_goodbye`).
    #[snafu(display("Failed to send goodbye packet: {}", source))]
    SendGoodbye { source: std::io::Error },

    /// Writing a packet's header or body to the writer failed (produced by `send`).
    #[snafu(display("Failed to send packet: {}", source))]
    Send { source: std::io::Error },

    /// The writer's `poll_flush` returned an error.
    #[snafu(display("Failed to flush sink: {}", source))]
    Flush { source: std::io::Error },

    /// The writer's `poll_close` returned an error.
    #[snafu(display("Error while closing sink: {}", source))]
    Close { source: std::io::Error },
}
26
27async fn send<W>(mut w: W, msg: Packet) -> (W, Result<(), Error>)
28where
29 W: AsyncWrite + Unpin + 'static,
30{
31 let h = msg.header();
32 let mut r = w.write_all(&h).await;
33 if r.is_ok() {
34 r = w.write_all(&msg.body).await;
35 }
36 (w, r.map(|_| ()).context(Send))
37}
38
39async fn send_goodbye<W>(mut w: W) -> (W, Result<(), Error>)
40where
41 W: AsyncWrite + Unpin + 'static,
42{
43 let r = w.write_all(&[0; 9]).await;
44 (w, r.map(|_| ()).context(SendGoodbye {}))
45}
46
/// A [`Sink`] of [`Packet`]s that writes each packet to an underlying
/// [`AsyncWrite`] of type `W`.
///
/// Closing the sink first writes a goodbye packet (nine zero bytes) and then
/// closes the writer.
pub struct PacketSink<W> {
    // Current step of the send/close state machine. It owns the writer,
    // either directly or inside an in-flight write future.
    state: State<W>,
}
73impl<W> PacketSink<W> {
74 pub fn new(w: W) -> PacketSink<W> {
75 PacketSink {
76 state: State::Ready(w),
77 }
78 }
79
80 pub fn into_inner(mut self) -> W {
81 match self.state.take() {
82 State::Ready(w) | State::Closing(w, _) | State::Closed(w) => w,
83 _ => panic!(),
84 }
85 }
86}
87
// State machine behind `PacketSink`. The writer is always owned by the
// current state, either directly or by the boxed future that is using it.
enum State<W> {
    // Idle: the writer is available and a packet may be started.
    Ready(W),
    // A `send` future is writing a packet; it yields the writer back with the result.
    Sending(PinFut<(W, Result<(), Error>)>),
    // A `send_goodbye` future is writing the goodbye bytes; it yields the writer back with the result.
    SendingGoodbye(PinFut<(W, Result<(), Error>)>),
    // The writer is being closed. A `Some` error is an earlier write failure
    // that is reported once the close completes.
    Closing(W, Option<Error>),
    // The writer's close has completed.
    Closed(W),
    // Placeholder left behind by `take` while the real state is moved out.
    Invalid,
}
96impl<W> State<W> {
97 fn take(&mut self) -> Self {
98 replace(self, State::Invalid)
99 }
100}
101
102fn flush<W>(state: State<W>, cx: &mut Context) -> (State<W>, Poll<Result<(), Error>>)
103where
104 W: AsyncWrite + Unpin + 'static,
105{
106 match state {
107 State::Ready(mut w) => {
108 let p = Pin::new(&mut w).poll_flush(cx).map(|r| r.context(Flush));
109 (State::Ready(w), p)
110 }
111 State::Sending(mut f) => match f.as_mut().poll(cx) {
112 Pending => (State::Sending(f), Pending),
113 Ready((w, Err(e))) => close(State::Closing(w, Some(e)), cx),
114 Ready((mut w, Ok(()))) => {
115 let p = Pin::new(&mut w).poll_flush(cx).map(|r| r.context(Flush));
116 (State::Ready(w), p)
117 }
118 },
119 _ => panic!(), }
121}
122
123fn close<W>(state: State<W>, cx: &mut Context) -> (State<W>, Poll<Result<(), Error>>)
124where
125 W: AsyncWrite + Unpin + 'static,
126{
127 match state {
128 State::Ready(w) => close(State::SendingGoodbye(Box::pin(send_goodbye(w))), cx),
129 State::Sending(mut f) => match f.as_mut().poll(cx) {
130 Pending => (State::Sending(f), Pending),
131 Ready((w, Ok(()))) => close(State::SendingGoodbye(Box::pin(send_goodbye(w))), cx),
132 Ready((w, Err(e))) => close(State::Closing(w, Some(e)), cx),
133 },
134 State::SendingGoodbye(mut f) => match f.as_mut().poll(cx) {
135 Pending => (State::SendingGoodbye(f), Pending),
136 Ready((w, Err(e))) => close(State::Closing(w, Some(e)), cx),
137 Ready((mut w, Ok(()))) => {
138 let p = Pin::new(&mut w).poll_close(cx).map(|r| r.context(Close));
139 (State::Closing(w, None), p)
140 }
141 },
142 State::Closing(mut w, e) => {
143 match (Pin::new(&mut w).poll_close(cx).map(|r| r.context(Close)), e) {
144 (Pending, e) => (State::Closing(w, e), Pending),
145 (Ready(r), None) => (State::Closed(w), Ready(r)),
146 (Ready(_), Some(e)) => (State::Closed(w), Ready(Err(e))), }
148 }
149
150 st @ State::Closed(_) => (st, Ready(Ok(()))),
151 _ => panic!(),
152 }
153}
154
155impl<W> Sink<Packet> for PacketSink<W>
156where
157 W: AsyncWrite + Unpin + 'static,
158{
159 type Error = Error;
160
161 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
162 self.poll_flush(cx)
163 }
164
165 fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
166 match self.state.take() {
167 State::Ready(w) => {
168 self.state = State::Sending(Box::pin(send(w, item)));
169 Ok(())
170 }
171 _ => panic!(),
172 }
173 }
174
175 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
176 let (state, poll) = flush(self.state.take(), cx);
177 self.state = state;
178 poll
179 }
180
181 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
182 let (state, poll) = close(self.state.take(), cx);
183 self.state = state;
184 poll
185 }
186}