xitca_postgres/driver/
codec.rs

1pub(crate) mod encode;
2pub(crate) mod response;
3
4use core::{
5    future::{poll_fn, Future},
6    task::{ready, Context, Poll},
7};
8
9use postgres_protocol::message::backend;
10use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
11use xitca_io::bytes::BytesMut;
12
13use crate::{
14    error::{Completed, DriverDownReceiving, Error},
15    types::BorrowToSql,
16};
17
18use super::DriverTx;
19
20pub(super) fn request_pair(msg_count: usize) -> (ResponseSender, Response) {
21    let (tx, rx) = unbounded_channel();
22    (
23        ResponseSender { tx, msg_count },
24        Response {
25            rx,
26            buf: BytesMut::new(),
27            complete: false,
28        },
29    )
30}
31
32#[derive(Debug)]
33pub struct Response {
34    rx: ResponseReceiver,
35    buf: BytesMut,
36    complete: bool,
37}
38
39impl Response {
40    pub(crate) fn blocking_recv(&mut self) -> Result<backend::Message, Error> {
41        if self.buf.is_empty() {
42            let res = self.rx.blocking_recv();
43            self.on_recv(res)?;
44        }
45        self.parse_message()
46    }
47
48    pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
49        poll_fn(|cx| self.poll_recv(cx))
50    }
51
52    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<backend::Message, Error>> {
53        if self.buf.is_empty() {
54            let res = ready!(self.rx.poll_recv(cx));
55            self.on_recv(res)?;
56        }
57        Poll::Ready(self.parse_message())
58    }
59
60    pub(crate) fn try_into_row_affected(mut self) -> impl Future<Output = Result<u64, Error>> + Send {
61        let mut rows = 0;
62        poll_fn(move |cx| {
63            ready!(self.poll_try_into_ready(&mut rows, cx))?;
64            Poll::Ready(Ok(rows))
65        })
66    }
67
68    pub(crate) fn try_into_row_affected_blocking(mut self) -> Result<u64, Error> {
69        let mut rows = 0;
70        loop {
71            match self.blocking_recv()? {
72                backend::Message::BindComplete
73                | backend::Message::NoData
74                | backend::Message::ParseComplete
75                | backend::Message::ParameterDescription(_)
76                | backend::Message::RowDescription(_)
77                | backend::Message::DataRow(_)
78                | backend::Message::EmptyQueryResponse
79                | backend::Message::PortalSuspended => {}
80                backend::Message::CommandComplete(body) => {
81                    rows = body_to_affected_rows(&body)?;
82                }
83                backend::Message::ReadyForQuery(_) => return Ok(rows),
84                _ => return Err(Error::unexpected()),
85            }
86        }
87    }
88
89    pub(crate) fn poll_try_into_ready(&mut self, rows: &mut u64, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
90        loop {
91            match ready!(self.poll_recv(cx))? {
92                backend::Message::BindComplete
93                | backend::Message::NoData
94                | backend::Message::ParseComplete
95                | backend::Message::ParameterDescription(_)
96                | backend::Message::RowDescription(_)
97                | backend::Message::DataRow(_)
98                | backend::Message::EmptyQueryResponse
99                | backend::Message::PortalSuspended => {}
100                backend::Message::CommandComplete(body) => {
101                    *rows = body_to_affected_rows(&body)?;
102                }
103                backend::Message::ReadyForQuery(_) => return Poll::Ready(Ok(())),
104                _ => return Poll::Ready(Err(Error::unexpected())),
105            }
106        }
107    }
108
109    fn on_recv(&mut self, res: Option<BytesMessage>) -> Result<(), Error> {
110        match res {
111            Some(msg) => {
112                self.complete = msg.complete;
113                self.buf = msg.buf;
114                Ok(())
115            }
116            None => {
117                return if self.complete {
118                    Err(Completed.into())
119                } else {
120                    Err(DriverDownReceiving.into())
121                }
122            }
123        }
124    }
125
126    fn parse_message(&mut self) -> Result<backend::Message, Error> {
127        match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") {
128            backend::Message::ErrorResponse(body) => Err(Error::db(body.fields())),
129            msg => Ok(msg),
130        }
131    }
132}
133
134// Extract the number of rows affected.
135pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result<u64, Error> {
136    body.tag()
137        .map_err(|_| Error::todo())
138        .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0))
139}
140
141#[derive(Debug)]
142pub(crate) struct ResponseSender {
143    tx: UnboundedSender<BytesMessage>,
144    msg_count: usize,
145}
146
147pub(super) enum SenderState {
148    Continue,
149    Finish,
150}
151
152impl ResponseSender {
153    pub(super) fn send(&mut self, msg: BytesMessage) -> SenderState {
154        debug_assert!(self.msg_count > 0);
155
156        if msg.complete {
157            self.msg_count -= 1;
158        }
159
160        let _ = self.tx.send(msg);
161
162        if self.msg_count == 0 {
163            SenderState::Finish
164        } else {
165            SenderState::Continue
166        }
167    }
168}
169
170// TODO: remove this lint.
171#[allow(dead_code)]
172pub(super) type ResponseReceiver = UnboundedReceiver<BytesMessage>;
173
174pub(super) struct BytesMessage {
175    buf: BytesMut,
176    complete: bool,
177}
178
179impl BytesMessage {
180    #[cold]
181    #[inline(never)]
182    pub(super) fn parse_error(&mut self) -> Error {
183        match backend::Message::parse(&mut self.buf) {
184            Err(e) => Error::from(e),
185            Ok(Some(backend::Message::ErrorResponse(body))) => Error::db(body.fields()),
186            _ => Error::unexpected(),
187        }
188    }
189}
190
191pub(super) enum ResponseMessage {
192    Normal(BytesMessage),
193    Async(backend::Message),
194}
195
196impl ResponseMessage {
197    pub(crate) fn try_from_buf(buf: &mut BytesMut) -> Result<Option<Self>, Error> {
198        let mut tail = 0;
199        let mut complete = false;
200
201        loop {
202            let slice = &buf[tail..];
203            let Some(header) = backend::Header::parse(slice)? else {
204                break;
205            };
206            let len = header.len() as usize + 1;
207
208            if slice.len() < len {
209                break;
210            }
211
212            match header.tag() {
213                backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => {
214                    if tail > 0 {
215                        break;
216                    }
217                    let message = backend::Message::parse(buf)?
218                        .expect("buffer contains at least one Message. parser must produce Some");
219                    return Ok(Some(ResponseMessage::Async(message)));
220                }
221                tag => {
222                    tail += len;
223                    if matches!(tag, backend::READY_FOR_QUERY_TAG) {
224                        complete = true;
225                        break;
226                    }
227                }
228            }
229        }
230
231        if tail == 0 {
232            Ok(None)
233        } else {
234            Ok(Some(ResponseMessage::Normal(BytesMessage {
235                buf: buf.split_to(tail),
236                complete,
237            })))
238        }
239    }
240}
241
242/// traits for converting typed parameters into exact sized iterator where it yields
243/// item can be converted in binary format of postgres type.
244pub trait AsParams: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql>> {}
245
246impl<I> AsParams for I where I: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql>> {}
247
mod sealed {
    // Marker trait. `sealed` is a private module, so only this module and its child modules
    // can name `Sealed` — presumably it is used to seal traits against outside implementations;
    // verify usages in the child modules.
    pub trait Sealed {}
}