xitca_postgres/driver/
generic.rs

1use core::{
2    future::{poll_fn, Future},
3    pin::Pin,
4    task::Poll,
5};
6
7use std::{
8    collections::VecDeque,
9    io,
10    sync::{Arc, Mutex},
11};
12
13use futures_core::task::__internal::AtomicWaker;
14use postgres_protocol::message::{backend, frontend};
15use xitca_io::{
16    bytes::{Buf, BufRead, BytesMut},
17    io::{AsyncIo, Interest},
18};
19use xitca_unsafe_collection::futures::{Select as _, SelectOutput};
20
21use crate::error::{DriverDown, Error};
22
23use super::codec::{Response, ResponseMessage, ResponseSender, SenderState};
24
// read buffer type paged in 4096 byte chunks.
type PagedBytesMut = xitca_unsafe_collection::bytes::PagedBytesMut<4096>;

// combined readable + writable interest used whenever the driver has pending outgoing bytes.
const INTEREST_READ_WRITE: Interest = Interest::READABLE.add(Interest::WRITABLE);
28
// sender handle used by client code to queue frontend messages for the driver task.
pub(crate) struct DriverTx(Arc<SharedState>);

impl Drop for DriverTx {
    fn drop(&mut self) {
        {
            let mut state = self.0.guarded.lock().unwrap();
            // queue a postgres Terminate message so the server side can end the session cleanly.
            frontend::terminate(&mut state.buf);
            state.closed = true;
        }
        // wake the driver after releasing the lock so it can flush the
        // terminate message and begin shutdown.
        self.0.waker.wake();
    }
}
41
42impl DriverTx {
43    pub(crate) fn is_closed(&self) -> bool {
44        Arc::strong_count(&self.0) == 1
45    }
46
47    pub(crate) fn send_one_way<F>(&self, func: F) -> Result<(), Error>
48    where
49        F: FnOnce(&mut BytesMut) -> Result<(), Error>,
50    {
51        self._send(func, |_| {})?;
52        Ok(())
53    }
54
55    pub(crate) fn send<F, O>(&self, func: F, msg_count: usize) -> Result<(O, Response), Error>
56    where
57        F: FnOnce(&mut BytesMut) -> Result<O, Error>,
58    {
59        self._send(func, |inner| {
60            let (tx, rx) = super::codec::request_pair(msg_count);
61            inner.res.push_back(tx);
62            rx
63        })
64    }
65
66    fn _send<F, F2, O, T>(&self, func: F, on_send: F2) -> Result<(O, T), Error>
67    where
68        F: FnOnce(&mut BytesMut) -> Result<O, Error>,
69        F2: FnOnce(&mut State) -> T,
70    {
71        let mut inner = self.0.guarded.lock().unwrap();
72
73        if inner.closed {
74            return Err(DriverDown.into());
75        }
76
77        let len = inner.buf.len();
78
79        let o = func(&mut inner.buf).inspect_err(|_| inner.buf.truncate(len))?;
80        let t = on_send(&mut inner);
81
82        drop(inner);
83        self.0.waker.wake();
84
85        Ok((o, t))
86    }
87}
88
// state shared between DriverTx handles and the driver task.
pub(crate) struct SharedState {
    guarded: Mutex<State>,
    // waker of the driver task; woken when bytes are queued or the state is closed.
    waker: AtomicWaker,
}
93
94impl SharedState {
95    async fn wait(&self) -> WaitState {
96        poll_fn(|cx| {
97            let inner = self.guarded.lock().unwrap();
98            if !inner.buf.is_empty() {
99                Poll::Ready(WaitState::WantWrite)
100            } else if inner.closed {
101                Poll::Ready(WaitState::WantClose)
102            } else {
103                drop(inner);
104                self.waker.register(cx.waker());
105                Poll::Pending
106            }
107        })
108        .await
109    }
110}
111
// outcome of SharedState::wait.
enum WaitState {
    // bytes are queued in the shared write buffer.
    WantWrite,
    // the shared state has been marked closed.
    WantClose,
}
116
// mutex-guarded interior of SharedState.
struct State {
    // set on DriverTx drop, driver drop or write error; further sends observe DriverDown.
    closed: bool,
    // outgoing frontend bytes waiting to be written to the io.
    buf: BytesMut,
    // response senders for in-flight requests, in send order.
    res: VecDeque<ResponseSender>,
}
122
// io driver pumping queued frontend bytes to the io and decoding backend messages from it.
pub struct GenericDriver<Io> {
    io: Io,
    // buffered bytes read from io, consumed by try_decode.
    read_buf: PagedBytesMut,
    // state shared with DriverTx handles.
    shared_state: Arc<SharedState>,
    // read half of the driver state machine.
    read_state: ReadState,
    // write half of the driver state machine.
    write_state: WriteState,
}
130
// in case driver is dropped without closing the shared state
impl<Io> Drop for GenericDriver<Io> {
    fn drop(&mut self) {
        // mark closed so any later DriverTx::send observes DriverDown instead of
        // queueing bytes no one will ever write.
        self.shared_state.guarded.lock().unwrap().closed = true;
    }
}
137
// write half of the driver state machine.
enum WriteState {
    // nothing queued; wait to be woken by a DriverTx.
    Waiting,
    // bytes queued in the shared buffer need writing.
    WantWrite,
    // all queued bytes written; io still needs a flush.
    WantFlush,
    // write side shut down, optionally carrying the error that caused it.
    Closed(Option<io::Error>),
}
144
// read half of the driver state machine.
enum ReadState {
    // io may still produce bytes.
    WantRead,
    // read side shut down; None for a clean EOF, Some for an io error.
    Closed(Option<io::Error>),
}
149
150impl<Io> GenericDriver<Io>
151where
152    Io: AsyncIo + Send,
153{
154    pub(crate) fn new(io: Io) -> (Self, DriverTx) {
155        let state = Arc::new(SharedState {
156            guarded: Mutex::new(State {
157                closed: false,
158                buf: BytesMut::new(),
159                res: VecDeque::new(),
160            }),
161            waker: AtomicWaker::new(),
162        });
163
164        (
165            Self {
166                io,
167                read_buf: PagedBytesMut::new(),
168                shared_state: state.clone(),
169                read_state: ReadState::WantRead,
170                write_state: WriteState::Waiting,
171            },
172            DriverTx(state),
173        )
174    }
175
176    pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> {
177        self.shared_state.guarded.lock().unwrap().buf.extend_from_slice(&msg);
178        self.write_state = WriteState::WantWrite;
179        loop {
180            self.try_write()?;
181            if matches!(self.write_state, WriteState::Waiting) {
182                return Ok(());
183            }
184            self.io.ready(Interest::WRITABLE).await?;
185        }
186    }
187
188    pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
189        self.recv_with(|buf| backend::Message::parse(buf).map_err(Error::from).transpose())
190    }
191
    /// drive the connection until an out-of-band (async) backend message
    /// arrives or the connection fully shuts down (Ok(None)).
    ///
    /// routed response messages are dispatched to their response channels
    /// internally and never returned from here.
    pub(crate) async fn try_next(&mut self) -> Result<Option<backend::Message>, Error> {
        loop {
            // dispatch everything already buffered before touching io again.
            if let Some(msg) = self.try_decode()? {
                return Ok(Some(msg));
            }

            // choose the io interest to wait for from the combined read/write state.
            let ready = match (&mut self.read_state, &mut self.write_state) {
                (ReadState::WantRead, WriteState::Waiting) => {
                    // idle write side: race readable io against a wakeup from DriverTx.
                    match self.shared_state.wait().select(self.io.ready(Interest::READABLE)).await {
                        SelectOutput::A(WaitState::WantWrite) => {
                            self.write_state = WriteState::WantWrite;
                            self.io.ready(INTEREST_READ_WRITE).await?
                        }
                        SelectOutput::A(WaitState::WantClose) => {
                            self.write_state = WriteState::Closed(None);
                            continue;
                        }
                        SelectOutput::B(ready) => ready?,
                    }
                }
                (ReadState::WantRead, WriteState::WantWrite) => self.io.ready(INTEREST_READ_WRITE).await?,
                (ReadState::WantRead, WriteState::WantFlush) => {
                    // before flush io do a quick buffer check and go into write io state if possible.
                    if !self.shared_state.guarded.lock().unwrap().buf.is_empty() {
                        self.write_state = WriteState::WantWrite;
                    }
                    self.io.ready(INTEREST_READ_WRITE).await?
                }
                // write side closed: keep draining reads only.
                (ReadState::WantRead, WriteState::Closed(_)) => self.io.ready(Interest::READABLE).await?,
                (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite) => {
                    self.io.ready(Interest::WRITABLE).await?
                }
                // read side closed and nothing to write: wait for a sender wakeup only.
                (ReadState::Closed(_), WriteState::Waiting) => match self.shared_state.wait().await {
                    WaitState::WantWrite => {
                        self.write_state = WriteState::WantWrite;
                        self.io.ready(Interest::WRITABLE).await?
                    }
                    WaitState::WantClose => {
                        self.write_state = WriteState::Closed(None);
                        continue;
                    }
                },
                // both sides closed cleanly: shut the io down and finish.
                (ReadState::Closed(None), WriteState::Closed(None)) => {
                    poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?;
                    return Ok(None);
                }
                // both sides closed with at least one error: surface it.
                (ReadState::Closed(read_err), WriteState::Closed(write_err)) => {
                    return Err(Error::driver_io(read_err.take(), write_err.take()))
                }
            };

            // io errors do not abort the loop; they transition the matching
            // state machine half into Closed via the cold handlers.
            if ready.is_readable() {
                if let Err(e) = self.try_read() {
                    self.on_read_err(e);
                };
            }

            if ready.is_writable() {
                if let Err(e) = self.try_write() {
                    self.on_write_err(e);
                }
            }
        }
    }
256
257    async fn recv_with<F, O>(&mut self, mut func: F) -> Result<O, Error>
258    where
259        F: FnMut(&mut BytesMut) -> Option<Result<O, Error>>,
260    {
261        loop {
262            if let Some(o) = func(self.read_buf.get_mut()) {
263                return o;
264            }
265            self.io.ready(Interest::READABLE).await?;
266            self.try_read()?;
267        }
268    }
269
    /// drain readable bytes from io into the paged read buffer.
    fn try_read(&mut self) -> io::Result<()> {
        self.read_buf.do_io(&mut self.io)
    }
273
    /// drive the write state machine: write queued bytes and flush the io
    /// until everything is out or the io would block.
    ///
    /// must only be called while write_state is WantWrite or WantFlush.
    fn try_write(&mut self) -> io::Result<()> {
        loop {
            match self.write_state {
                WriteState::WantFlush => {
                    // one-shot flush attempt: on WouldBlock stay in WantFlush and
                    // retry when io reports writable again.
                    match io::Write::flush(&mut self.io) {
                        Ok(_) => self.write_state = WriteState::Waiting,
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
                        Err(e) => return Err(e),
                    }
                    break;
                }
                WriteState::WantWrite => {
                    let mut inner = self.shared_state.guarded.lock().unwrap();

                    match io::Write::write(&mut self.io, &inner.buf) {
                        // zero-length write means the peer accepts no more bytes.
                        Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                        Ok(n) => {
                            inner.buf.advance(n);

                            // buffer drained: loop around into the flush branch.
                            if inner.buf.is_empty() {
                                self.write_state = WriteState::WantFlush;
                            }
                        }
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => return Err(e),
                    }
                }
                _ => unreachable!("try_write must not be called when WriteState is wait or closed"),
            }
        }

        Ok(())
    }
307
308    #[cold]
309    fn on_read_err(&mut self, e: io::Error) {
310        let reason = (e.kind() != io::ErrorKind::UnexpectedEof).then_some(e);
311        self.read_state = ReadState::Closed(reason);
312    }
313
314    #[cold]
315    fn on_write_err(&mut self, e: io::Error) {
316        {
317            // when write error occur the driver would go into half close state(read only).
318            // clearing write_buf would drop all pending requests in it and hint the driver
319            // no future Interest::WRITABLE should be passed to AsyncIo::ready method.
320            let mut inner = self.shared_state.guarded.lock().unwrap();
321            inner.buf.clear();
322            // close shared state early so driver tx can observe the shutdown in first hand
323            inner.closed = true;
324        }
325        self.write_state = WriteState::Closed(Some(e));
326    }
327
    /// decode complete backend messages out of the read buffer.
    ///
    /// routed messages are forwarded to the front-most in-flight response
    /// sender; out-of-band (async) messages are returned to the caller.
    /// returns Ok(None) once no further complete message is buffered.
    fn try_decode(&mut self) -> Result<Option<backend::Message>, Error> {
        while let Some(res) = ResponseMessage::try_from_buf(self.read_buf.get_mut())? {
            match res {
                ResponseMessage::Normal(mut msg) => {
                    let mut inner = self.shared_state.guarded.lock().unwrap();
                    // a routed message with no waiting sender means request and
                    // response streams are out of sync; surface it as an error.
                    let front = inner.res.front_mut().ok_or_else(|| msg.parse_error())?;
                    match front.send(msg) {
                        SenderState::Finish => {
                            // sender received its final expected message; retire it.
                            inner.res.pop_front();
                        }
                        SenderState::Continue => {}
                    }
                }
                ResponseMessage::Async(msg) => return Ok(Some(msg)),
            }
        }
        Ok(None)
    }
346}