sqlx_etorreborre_postgres/connection/stream.rs

1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut};
3use std::str::FromStr;
4
5use futures_channel::mpsc::UnboundedSender;
6use futures_util::SinkExt;
7use log::Level;
8use sqlx_core::bytes::{Buf, Bytes};
9
10use crate::connection::tls::MaybeUpgradeTls;
11use crate::error::Error;
12use crate::io::{Decode, Encode};
13use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
14use crate::net::{self, BufferedSocket, Socket};
15use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
16
17// the stream is a separate type from the connection to uphold the invariant where an instantiated
18// [PgConnection] is a **valid** connection to postgres
19
20// when a new connection is asked for, we work directly on the [PgStream] type until the
21// connection is fully established
22
23// in other words, `self` in any PgConnection method is a live connection to postgres that
24// is fully prepared to receive queries
25
26pub struct PgStream {
27    // A trait object is okay here as the buffering amortizes the overhead of both the dynamic
28    // function call as well as the syscall.
29    inner: BufferedSocket<Box<dyn Socket>>,
30
31    // buffer of unreceived notification messages from `PUBLISH`
32    // this is set when creating a PgListener and only written to if that listener is
33    // re-used for query execution in-between receiving messages
34    pub(crate) notifications: Option<UnboundedSender<Notification>>,
35
36    pub(crate) parameter_statuses: BTreeMap<String, String>,
37
38    pub(crate) server_version_num: Option<u32>,
39}
40
41impl PgStream {
42    pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
43        let socket_future = match options.fetch_socket() {
44            Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
45            None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
46        };
47
48        let socket = socket_future.await?;
49
50        Ok(Self {
51            inner: BufferedSocket::new(socket),
52            notifications: None,
53            parameter_statuses: BTreeMap::default(),
54            server_version_num: None,
55        })
56    }
57
58    pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
59    where
60        T: Encode<'en>,
61    {
62        self.write(message);
63        self.flush().await?;
64        Ok(())
65    }
66
67    // Expect a specific type and format
68    pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
69        &mut self,
70        format: MessageFormat,
71    ) -> Result<T, Error> {
72        let message = self.recv().await?;
73
74        if message.format != format {
75            return Err(err_protocol!(
76                "expecting {:?} but received {:?}",
77                format,
78                message.format
79            ));
80        }
81
82        message.decode()
83    }
84
85    pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
86        // all packets in postgres start with a 5-byte header
87        // this header contains the message type and the total length of the message
88        let mut header: Bytes = self.inner.read(5).await?;
89
90        let format = MessageFormat::try_from_u8(header.get_u8())?;
91        let size = (header.get_u32() - 4) as usize;
92
93        let contents = self.inner.read(size).await?;
94
95        Ok(Message { format, contents })
96    }
97
98    // Get the next message from the server
99    // May wait for more data from the server
100    pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
101        loop {
102            let message = self.recv_unchecked().await?;
103
104            match message.format {
105                MessageFormat::ErrorResponse => {
106                    // An error returned from the database server.
107                    return Err(PgDatabaseError(message.decode()?).into());
108                }
109
110                MessageFormat::NotificationResponse => {
111                    if let Some(buffer) = &mut self.notifications {
112                        let notification: Notification = message.decode()?;
113                        let _ = buffer.send(notification).await;
114
115                        continue;
116                    }
117                }
118
119                MessageFormat::ParameterStatus => {
120                    // informs the frontend about the current (initial)
121                    // setting of backend parameters
122
123                    let ParameterStatus { name, value } = message.decode()?;
124                    // TODO: handle `client_encoding`, `DateStyle` change
125
126                    match name.as_str() {
127                        "server_version" => {
128                            self.server_version_num = parse_server_version(&value);
129                        }
130                        _ => {
131                            self.parameter_statuses.insert(name, value);
132                        }
133                    }
134
135                    continue;
136                }
137
138                MessageFormat::NoticeResponse => {
139                    // do we need this to be more configurable?
140                    // if you are reading this comment and think so, open an issue
141
142                    let notice: Notice = message.decode()?;
143
144                    let (log_level, tracing_level) = match notice.severity() {
145                        PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
146                            (Level::Error, tracing::Level::ERROR)
147                        }
148                        PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
149                        PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
150                        PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
151                        PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
152                    };
153
154                    let log_is_enabled = log::log_enabled!(
155                        target: "sqlx::postgres::notice",
156                        log_level
157                    ) || sqlx_core::private_tracing_dynamic_enabled!(
158                        target: "sqlx::postgres::notice",
159                        tracing_level
160                    );
161                    if log_is_enabled {
162                        let message = format!("{}", notice.message());
163                        sqlx_core::private_tracing_dynamic_event!(
164                            target: "sqlx::postgres::notice",
165                            tracing_level,
166                            message
167                        );
168                    }
169
170                    continue;
171                }
172
173                _ => {}
174            }
175
176            return Ok(message);
177        }
178    }
179}
180
181impl Deref for PgStream {
182    type Target = BufferedSocket<Box<dyn Socket>>;
183
184    #[inline]
185    fn deref(&self) -> &Self::Target {
186        &self.inner
187    }
188}
189
190impl DerefMut for PgStream {
191    #[inline]
192    fn deref_mut(&mut self) -> &mut Self::Target {
193        &mut self.inner
194    }
195}
196
// reference:
// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065

/// Converts a `server_version` parameter string into libpq's numeric form
/// (the value `PQserverVersion` reports), e.g. `"9.6.1"` -> `90601`, `"10.1"` -> `100001`.
///
/// Like libpq's `sscanf(value, "%d.%d.%d", ...)`, at most three leading dot-separated
/// integers are read. Scanning stops at the first character that does not continue the
/// version, so suffixes such as `devel` or ` (Ubuntu ...)` are ignored.
///
/// Returns `None` if the string does not start with an integer, or if the combined version
/// number would not fit in a `u32`.
fn parse_server_version(s: &str) -> Option<u32> {
    let mut parts = Vec::<u32>::with_capacity(3);

    let mut rest = s;
    while parts.len() < 3 {
        // Byte length of the leading run of ASCII digits (always a char boundary).
        let digits_len = rest
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(rest.len());

        match u32::from_str(&rest[..digits_len]) {
            Ok(num) => parts.push(num),
            // no digits at all, or a component too large for `u32`
            Err(_) => break,
        }

        // Only a `.` continues the version; anything else ends it.
        match rest[digits_len..].strip_prefix('.') {
            Some(tail) => rest = tail,
            None => break,
        }
    }

    // Checked arithmetic: an absurdly large component must yield `None`, not overflow
    // (which panics in debug builds and silently wraps in release builds).
    match *parts.as_slice() {
        // old style, e.g. 9.6.1
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev),
        // new style, e.g. 10.1
        [major, minor] if major >= 10 => major.checked_mul(100 * 100)?.checked_add(minor),
        // old style without minor version, e.g. 9.6devel
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100),
        // new style without minor version, e.g. 10devel
        [major] => major.checked_mul(100 * 100),
        _ => None,
    }
}
241
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        // old style: major.minor.revision
        assert_eq!(parse_server_version("9.6.1"), Some(90601));
        // new style: major.minor
        assert_eq!(parse_server_version("10.1"), Some(100001));
        // old style without minor version, e.g. a development build
        assert_eq!(parse_server_version("9.6devel"), Some(90600));
        // new style without minor version, e.g. a development build
        assert_eq!(parse_server_version("10devel"), Some(100000));
        assert_eq!(parse_server_version("13devel87"), Some(130000));
        // trailing build / distribution information is ignored
        assert_eq!(
            parse_server_version("16.1 (Ubuntu 16.1-1.pgdg22.04+1)"),
            Some(160001)
        );
        assert_eq!(parse_server_version("9.6.1 (Debian 9.6.1-2)"), Some(90601));
        // unknown or empty
        assert_eq!(parse_server_version("unknown"), None);
        assert_eq!(parse_server_version(""), None);
    }
}