1use core::{
2 future::{poll_fn, Future},
3 pin::Pin,
4 task::Poll,
5};
6
7use std::{
8 collections::VecDeque,
9 io,
10 sync::{Arc, Mutex},
11};
12
13use futures_core::task::__internal::AtomicWaker;
14use postgres_protocol::message::{backend, frontend};
15use xitca_io::{
16 bytes::{Buf, BufRead, BytesMut},
17 io::{AsyncIo, Interest},
18};
19use xitca_unsafe_collection::futures::{Select as _, SelectOutput};
20
21use crate::error::{DriverDown, Error};
22
23use super::codec::{Response, ResponseMessage, ResponseSender, SenderState};
24
25type PagedBytesMut = xitca_unsafe_collection::bytes::PagedBytesMut<4096>;
26
27const INTEREST_READ_WRITE: Interest = Interest::READABLE.add(Interest::WRITABLE);
28
29pub(crate) struct DriverTx(Arc<SharedState>);
30
31impl Drop for DriverTx {
32 fn drop(&mut self) {
33 {
34 let mut state = self.0.guarded.lock().unwrap();
35 frontend::terminate(&mut state.buf);
36 state.closed = true;
37 }
38 self.0.waker.wake();
39 }
40}
41
42impl DriverTx {
43 pub(crate) fn is_closed(&self) -> bool {
44 Arc::strong_count(&self.0) == 1
45 }
46
47 pub(crate) fn send_one_way<F>(&self, func: F) -> Result<(), Error>
48 where
49 F: FnOnce(&mut BytesMut) -> Result<(), Error>,
50 {
51 self._send(func, |_| {})?;
52 Ok(())
53 }
54
55 pub(crate) fn send<F, O>(&self, func: F, msg_count: usize) -> Result<(O, Response), Error>
56 where
57 F: FnOnce(&mut BytesMut) -> Result<O, Error>,
58 {
59 self._send(func, |inner| {
60 let (tx, rx) = super::codec::request_pair(msg_count);
61 inner.res.push_back(tx);
62 rx
63 })
64 }
65
66 fn _send<F, F2, O, T>(&self, func: F, on_send: F2) -> Result<(O, T), Error>
67 where
68 F: FnOnce(&mut BytesMut) -> Result<O, Error>,
69 F2: FnOnce(&mut State) -> T,
70 {
71 let mut inner = self.0.guarded.lock().unwrap();
72
73 if inner.closed {
74 return Err(DriverDown.into());
75 }
76
77 let len = inner.buf.len();
78
79 let o = func(&mut inner.buf).inspect_err(|_| inner.buf.truncate(len))?;
80 let t = on_send(&mut inner);
81
82 drop(inner);
83 self.0.waker.wake();
84
85 Ok((o, t))
86 }
87}
88
89pub(crate) struct SharedState {
90 guarded: Mutex<State>,
91 waker: AtomicWaker,
92}
93
94impl SharedState {
95 async fn wait(&self) -> WaitState {
96 poll_fn(|cx| {
97 let inner = self.guarded.lock().unwrap();
98 if !inner.buf.is_empty() {
99 Poll::Ready(WaitState::WantWrite)
100 } else if inner.closed {
101 Poll::Ready(WaitState::WantClose)
102 } else {
103 drop(inner);
104 self.waker.register(cx.waker());
105 Poll::Pending
106 }
107 })
108 .await
109 }
110}
111
/// Outcome of `SharedState::wait`.
enum WaitState {
    /// The shared write buffer holds bytes to send.
    WantWrite,
    /// The write buffer is empty and the shared state has been marked closed.
    WantClose,
}

/// Mutable state shared between the driver and its `DriverTx`, always accessed under a mutex.
struct State {
    /// Set when `DriverTx` is dropped, when the driver is dropped, or after a write error.
    /// While set, `DriverTx::_send` fails with `DriverDown`.
    closed: bool,
    /// Outgoing protocol bytes: appended by `DriverTx`, drained by the driver's socket writes.
    buf: BytesMut,
    /// Senders for in-flight requests in the order they were queued. The driver routes each
    /// non-async backend message to the front entry.
    res: VecDeque<ResponseSender>,
}
122
123pub struct GenericDriver<Io> {
124 io: Io,
125 read_buf: PagedBytesMut,
126 shared_state: Arc<SharedState>,
127 read_state: ReadState,
128 write_state: WriteState,
129}
130
131impl<Io> Drop for GenericDriver<Io> {
133 fn drop(&mut self) {
134 self.shared_state.guarded.lock().unwrap().closed = true;
135 }
136}
137
/// Progress of the driver's write half.
enum WriteState {
    /// No outgoing data known; the driver waits on the shared state for more.
    Waiting,
    /// The shared write buffer has bytes that still need to be written to `io`.
    WantWrite,
    /// The buffer has been fully written; `io` still needs to be flushed.
    WantFlush,
    /// The write half is finished: `Some(err)` after a write error, `None` after an orderly
    /// close requested through the shared state.
    Closed(Option<io::Error>),
}

/// Progress of the driver's read half.
enum ReadState {
    /// Keep reading from `io`.
    WantRead,
    /// The read half is finished: `None` when the read failed with `UnexpectedEof` (treated as a
    /// clean close), `Some(err)` for any other read error.
    Closed(Option<io::Error>),
}
149
impl<Io> GenericDriver<Io>
where
    Io: AsyncIo + Send,
{
    /// Creates a driver over `io` together with the `DriverTx` handle that feeds it.
    ///
    /// Both share one `SharedState` (initially open, with an empty buffer and no pending
    /// responses). The driver starts in `ReadState::WantRead` / `WriteState::Waiting`.
    pub(crate) fn new(io: Io) -> (Self, DriverTx) {
        let state = Arc::new(SharedState {
            guarded: Mutex::new(State {
                closed: false,
                buf: BytesMut::new(),
                res: VecDeque::new(),
            }),
            waker: AtomicWaker::new(),
        });

        (
            Self {
                io,
                read_buf: PagedBytesMut::new(),
                shared_state: state.clone(),
                read_state: ReadState::WantRead,
                write_state: WriteState::Waiting,
            },
            DriverTx(state),
        )
    }

    /// Appends `msg` to the shared write buffer and writes/flushes until the buffer is drained.
    ///
    /// Waits only for write readiness and never reads. Presumably used outside the main
    /// `try_next` loop (e.g. during connection setup) — confirm with callers.
    ///
    /// NOTE(review): this neither checks `closed` nor respects `WriteState::Closed`; it always
    /// switches to `WantWrite`. With an empty `msg` and an empty buffer, the first write is
    /// likely to return `Ok(0)` and surface as `WriteZero`.
    pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> {
        self.shared_state.guarded.lock().unwrap().buf.extend_from_slice(&msg);
        self.write_state = WriteState::WantWrite;
        loop {
            self.try_write()?;
            // `try_write` returns to `Waiting` only after a successful flush.
            if matches!(self.write_state, WriteState::Waiting) {
                return Ok(());
            }
            self.io.ready(Interest::WRITABLE).await?;
        }
    }

    /// Reads from `io` until `backend::Message::parse` yields a complete message from
    /// `read_buf`, and returns it.
    ///
    /// Touches only the read side; queued writes and response routing are not driven here.
    pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
        self.recv_with(|buf| backend::Message::parse(buf).map_err(Error::from).transpose())
    }

    /// Drives the connection until an asynchronous backend message is available.
    ///
    /// Each iteration does three things:
    /// 1. Decodes what is already in `read_buf`, routing ordinary responses to their senders
    ///    via `try_decode`.
    /// 2. Waits for socket readiness and/or new outgoing data, depending on the current
    ///    `(read_state, write_state)` pair.
    /// 3. Performs the matching non-blocking read and write.
    ///
    /// Returns:
    /// - `Ok(Some(msg))` for a `ResponseMessage::Async` message;
    /// - `Ok(None)` once both halves are closed without error (after shutting `io` down);
    /// - `Err` when both halves are closed and at least one closed with an I/O error, or when
    ///   decoding, waiting for readiness or shutdown fails.
    pub(crate) async fn try_next(&mut self) -> Result<Option<backend::Message>, Error> {
        loop {
            if let Some(msg) = self.try_decode()? {
                return Ok(Some(msg));
            }

            // Pick what to wait for based on which halves are still active.
            // NOTE(review): the readable/writable handling below assumes `io.ready` reports only
            // the requested interest. A writable readiness while `write_state` is `Waiting` or
            // `Closed` would reach the `unreachable!` in `try_write`.
            let ready = match (&mut self.read_state, &mut self.write_state) {
                (ReadState::WantRead, WriteState::Waiting) => {
                    // Race new outgoing data (or a close request) against incoming data.
                    match self.shared_state.wait().select(self.io.ready(Interest::READABLE)).await {
                        SelectOutput::A(WaitState::WantWrite) => {
                            self.write_state = WriteState::WantWrite;
                            self.io.ready(INTEREST_READ_WRITE).await?
                        }
                        SelectOutput::A(WaitState::WantClose) => {
                            self.write_state = WriteState::Closed(None);
                            continue;
                        }
                        SelectOutput::B(ready) => ready?,
                    }
                }
                (ReadState::WantRead, WriteState::WantWrite) => self.io.ready(INTEREST_READ_WRITE).await?,
                (ReadState::WantRead, WriteState::WantFlush) => {
                    // More data may have been queued since the last write; write it before flushing.
                    if !self.shared_state.guarded.lock().unwrap().buf.is_empty() {
                        self.write_state = WriteState::WantWrite;
                    }
                    self.io.ready(INTEREST_READ_WRITE).await?
                }
                (ReadState::WantRead, WriteState::Closed(_)) => self.io.ready(Interest::READABLE).await?,
                (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite) => {
                    self.io.ready(Interest::WRITABLE).await?
                }
                (ReadState::Closed(_), WriteState::Waiting) => match self.shared_state.wait().await {
                    WaitState::WantWrite => {
                        self.write_state = WriteState::WantWrite;
                        self.io.ready(Interest::WRITABLE).await?
                    }
                    WaitState::WantClose => {
                        self.write_state = WriteState::Closed(None);
                        continue;
                    }
                },
                // Both halves closed cleanly: shut the connection down and report end of stream.
                (ReadState::Closed(None), WriteState::Closed(None)) => {
                    poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?;
                    return Ok(None);
                }
                // Both halves closed and at least one of them with an error.
                (ReadState::Closed(read_err), WriteState::Closed(write_err)) => {
                    return Err(Error::driver_io(read_err.take(), write_err.take()))
                }
            };

            // I/O errors close the corresponding half instead of aborting the loop, so the
            // other half can keep running.
            if ready.is_readable() {
                if let Err(e) = self.try_read() {
                    self.on_read_err(e);
                };
            }

            if ready.is_writable() {
                if let Err(e) = self.try_write() {
                    self.on_write_err(e);
                }
            }
        }
    }

    /// Repeatedly applies `func` to `read_buf`, reading more data from `io` whenever it returns
    /// `None`, until it yields a result.
    async fn recv_with<F, O>(&mut self, mut func: F) -> Result<O, Error>
    where
        F: FnMut(&mut BytesMut) -> Option<Result<O, Error>>,
    {
        loop {
            if let Some(o) = func(self.read_buf.get_mut()) {
                return o;
            }
            self.io.ready(Interest::READABLE).await?;
            self.try_read()?;
        }
    }

    /// Reads available bytes from `io` into `read_buf` via `PagedBytesMut::do_io`.
    ///
    /// Presumably reports end of stream as `UnexpectedEof` (see `on_read_err`) — confirm in
    /// `xitca_unsafe_collection`.
    fn try_read(&mut self) -> io::Result<()> {
        self.read_buf.do_io(&mut self.io)
    }

    /// Writes the shared buffer to `io` and flushes it, advancing `write_state` as it goes:
    /// `WantWrite` -> `WantFlush` once the buffer is drained, then `WantFlush` -> `Waiting` after
    /// a successful flush.
    ///
    /// Stops early (leaving the state unchanged) on `WouldBlock`. Returns `WriteZero` if a write
    /// accepts no bytes.
    ///
    /// # Panics
    /// If called while `write_state` is `Waiting` or `Closed`.
    fn try_write(&mut self) -> io::Result<()> {
        loop {
            match self.write_state {
                WriteState::WantFlush => {
                    match io::Write::flush(&mut self.io) {
                        Ok(_) => self.write_state = WriteState::Waiting,
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
                        Err(e) => return Err(e),
                    }
                    break;
                }
                WriteState::WantWrite => {
                    let mut inner = self.shared_state.guarded.lock().unwrap();

                    match io::Write::write(&mut self.io, &inner.buf) {
                        Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                        Ok(n) => {
                            // Drop the bytes that reached the socket; keep looping for the rest.
                            inner.buf.advance(n);

                            if inner.buf.is_empty() {
                                self.write_state = WriteState::WantFlush;
                            }
                        }
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
                        Err(e) => return Err(e),
                    }
                }
                _ => unreachable!("try_write must not be called when WriteState is wait or closed"),
            }
        }

        Ok(())
    }

    /// Closes the read half. `UnexpectedEof` is recorded as a clean close (`None`); any other
    /// error is kept and later reported by `try_next`.
    ///
    /// NOTE(review): senders for in-flight requests stay in the shared `res` queue here, even
    /// though no further responses can be read. Their callers keep waiting until the driver
    /// itself finishes.
    #[cold]
    fn on_read_err(&mut self, e: io::Error) {
        let reason = (e.kind() != io::ErrorKind::UnexpectedEof).then_some(e);
        self.read_state = ReadState::Closed(reason);
    }

    /// Closes the write half after a write error.
    ///
    /// Discards all buffered outgoing bytes and marks the shared state closed, so subsequent
    /// `DriverTx` sends fail with `DriverDown`.
    #[cold]
    fn on_write_err(&mut self, e: io::Error) {
        {
            let mut inner = self.shared_state.guarded.lock().unwrap();
            inner.buf.clear();
            inner.closed = true;
        }
        self.write_state = WriteState::Closed(Some(e));
    }

    /// Decodes complete messages from `read_buf`.
    ///
    /// Each `ResponseMessage::Normal` goes to the oldest pending sender, which is popped once it
    /// reports `SenderState::Finish`. The first `ResponseMessage::Async` is returned to the
    /// caller.
    ///
    /// # Errors
    /// Decoding failures, or a normal message arriving while no sender is pending (reported as
    /// `msg.parse_error()`).
    fn try_decode(&mut self) -> Result<Option<backend::Message>, Error> {
        while let Some(res) = ResponseMessage::try_from_buf(self.read_buf.get_mut())? {
            match res {
                ResponseMessage::Normal(mut msg) => {
                    let mut inner = self.shared_state.guarded.lock().unwrap();
                    let front = inner.res.front_mut().ok_or_else(|| msg.parse_error())?;
                    match front.send(msg) {
                        SenderState::Finish => {
                            inner.res.pop_front();
                        }
                        SenderState::Continue => {}
                    }
                }
                ResponseMessage::Async(msg) => return Ok(Some(msg)),
            }
        }
        Ok(None)
    }
}