Skip to main content

xitca_postgres/driver/
codec.rs

1pub(crate) mod encode;
2pub(crate) mod response;
3
4use core::{
5    future::{Future, poll_fn},
6    task::{Context, Poll, ready},
7};
8
9use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
10use xitca_io::bytes::BytesMut;
11
12use crate::{
13    error::{ClosedByDriver, Error},
14    protocol::message::backend,
15    types::BorrowToSql,
16};
17
18pub(super) fn request_pair() -> (ResponseSender, Response) {
19    let (tx, rx) = unbounded_channel();
20    (
21        tx,
22        Response {
23            rx,
24            buf: BytesMut::new(),
25        },
26    )
27}
28
/// Receiving side of a response stream.
///
/// Raw byte chunks arrive over an unbounded channel, whose sending half is
/// returned alongside this value by `request_pair`. They are parsed into
/// [`backend::Message`]s on demand.
#[derive(Debug)]
pub struct Response {
    /// Channel delivering chunks of raw backend message bytes.
    rx: ResponseReceiver,
    /// Bytes already received from `rx` but not yet parsed into messages.
    buf: BytesMut,
}
34
35impl Response {
36    pub(crate) fn blocking_recv(&mut self) -> Result<backend::Message, Error> {
37        if self.buf.is_empty() {
38            self.buf = self.rx.blocking_recv().ok_or_else(|| Error::from(ClosedByDriver))?;
39        }
40        self.parse_message()
41    }
42
43    pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
44        poll_fn(|cx| self.poll_recv(cx))
45    }
46
47    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<backend::Message, Error>> {
48        if self.buf.is_empty() {
49            self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::from(ClosedByDriver))?;
50        }
51        Poll::Ready(self.parse_message())
52    }
53
54    pub(crate) fn try_into_row_affected(mut self) -> impl Future<Output = Result<u64, Error>> + Send {
55        let mut rows = 0;
56        poll_fn(move |cx| {
57            ready!(self.poll_try_into_ready(&mut rows, cx))?;
58            Poll::Ready(Ok(rows))
59        })
60    }
61
62    pub(crate) fn try_into_row_affected_blocking(mut self) -> Result<u64, Error> {
63        let mut rows = 0;
64        loop {
65            match self.blocking_recv()? {
66                backend::Message::BindComplete
67                | backend::Message::NoData
68                | backend::Message::ParseComplete
69                | backend::Message::ParameterDescription(_)
70                | backend::Message::RowDescription(_)
71                | backend::Message::DataRow(_)
72                | backend::Message::EmptyQueryResponse
73                | backend::Message::PortalSuspended => {}
74                backend::Message::CommandComplete(body) => {
75                    rows = body_to_affected_rows(&body)?;
76                }
77                backend::Message::ReadyForQuery(_) => return Ok(rows),
78                _ => return Err(Error::unexpected()),
79            }
80        }
81    }
82
83    pub(crate) fn poll_try_into_ready(&mut self, rows: &mut u64, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
84        loop {
85            match ready!(self.poll_recv(cx))? {
86                backend::Message::BindComplete
87                | backend::Message::NoData
88                | backend::Message::ParseComplete
89                | backend::Message::ParameterDescription(_)
90                | backend::Message::RowDescription(_)
91                | backend::Message::DataRow(_)
92                | backend::Message::EmptyQueryResponse
93                | backend::Message::PortalSuspended => {}
94                backend::Message::CommandComplete(body) => {
95                    *rows = body_to_affected_rows(&body)?;
96                }
97                backend::Message::ReadyForQuery(_) => return Poll::Ready(Ok(())),
98                _ => return Poll::Ready(Err(Error::unexpected())),
99            }
100        }
101    }
102
103    fn parse_message(&mut self) -> Result<backend::Message, Error> {
104        match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") {
105            backend::Message::ErrorResponse(body) => Err(Error::db(body.fields())),
106            msg => Ok(msg),
107        }
108    }
109}
110
111// Extract the number of rows affected.
112pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result<u64, Error> {
113    body.tag()
114        .map_err(|_| Error::todo())
115        .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0))
116}
117
/// Sending half of a response channel: carries raw backend message bytes to a [`Response`].
pub(super) type ResponseSender = UnboundedSender<BytesMut>;

// TODO: remove this lint.
// NOTE(review): `Response::rx` in this file is typed with this alias, so the allow
// may already be unnecessary — confirm no cfg/feature combination leaves it unused.
#[allow(dead_code)]
/// Receiving half of a response channel, held by [`Response`].
pub(super) type ResponseReceiver = UnboundedReceiver<BytesMut>;
123
124pub(super) struct BytesMessage {
125    pub(super) buf: BytesMut,
126    pub(super) complete: bool,
127}
128
129impl BytesMessage {
130    #[cold]
131    #[inline(never)]
132    pub(super) fn parse_error(&mut self) -> Error {
133        match backend::Message::parse(&mut self.buf) {
134            Err(e) => Error::from(e),
135            Ok(Some(backend::Message::ErrorResponse(body))) => Error::db(body.fields()),
136            _ => Error::unexpected(),
137        }
138    }
139}
140
/// A unit of backend output carved off the front of the driver's read buffer.
pub(super) enum ResponseMessage {
    /// One or more consecutive complete messages, kept as raw bytes.
    Normal(BytesMessage),
    /// A single parsed notice, notification or parameter-status message.
    Async(backend::Message),
}

impl ResponseMessage {
    /// Splits the next group of complete backend messages off the front of `buf`.
    ///
    /// - If the first complete message is a notice, notification or parameter-status
    ///   message, that message alone is parsed, consumed from `buf` and returned as `Async`.
    /// - Otherwise, consecutive complete messages are collected until one of these is reached:
    ///   - a `ReadyForQuery` message (included; `complete` is set to `true`);
    ///   - a notice, notification or parameter-status message (excluded and left in
    ///     `buf` for the next call);
    ///   - a message that is not fully buffered yet.
    ///
    ///   The collected bytes are split off `buf` and returned as `Normal`.
    ///
    /// Returns `Ok(None)`, consuming nothing, when `buf` does not start with a
    /// complete message.
    ///
    /// # Errors
    ///
    /// Propagates errors from `backend::Header::parse` and `backend::Message::parse`.
    pub(crate) fn try_from_buf(buf: &mut BytesMut) -> Result<Option<Self>, Error> {
        // Byte offset just past the last complete message accepted into the `Normal` group.
        let mut tail = 0;
        // Set once the group ends with `ReadyForQuery`.
        let mut complete = false;

        loop {
            let slice = &buf[tail..];
            // No complete header available yet.
            let Some(header) = backend::Header::parse(slice)? else {
                break;
            };
            // Total message size. The `+ 1` covers the leading tag byte, which the length
            // field does not count (PostgreSQL wire format).
            let len = header.len() as usize + 1;

            // The message is not fully buffered yet.
            if slice.len() < len {
                break;
            }

            match header.tag() {
                backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => {
                    // Emit the already-collected messages first. The async message stays
                    // in `buf` and is picked up by the next call.
                    if tail > 0 {
                        break;
                    }
                    // `tail == 0`: the async message is at the front of `buf` and fully
                    // buffered (checked above), so parsing must yield it.
                    let message = backend::Message::parse(buf)?
                        .expect("buffer contains at least one Message. parser must produce Some");
                    return Ok(Some(ResponseMessage::Async(message)));
                }
                tag => {
                    tail += len;
                    // `ReadyForQuery` terminates the group.
                    if matches!(tag, backend::READY_FOR_QUERY_TAG) {
                        complete = true;
                        break;
                    }
                }
            }
        }

        if tail == 0 {
            Ok(None)
        } else {
            Ok(Some(ResponseMessage::Normal(BytesMessage {
                buf: buf.split_to(tail),
                complete,
            })))
        }
    }
}
191
192/// traits for converting typed parameters into exact sized iterator where it yields
193/// item can be converted in binary format of postgres type.
194pub trait AsParams: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql> + Clone> {}
195
196impl<I> AsParams for I where I: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql> + Clone> {}
197
/// Private module holding the `Sealed` marker trait.
///
/// Because the module is private, only `codec` and its child modules (`encode`,
/// `response`) can name `sealed::Sealed`.
/// NOTE(review): no use of it is visible in this file. Presumably it serves as a
/// supertrait that seals public traits elsewhere — confirm.
mod sealed {
    /// Marker trait with no methods.
    pub trait Sealed {}
}