// sqlx_core_oldapi/postgres/connection/stream.rs

1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut};
3use std::str::FromStr;
4
5use bytes::{Buf, Bytes};
6use futures_channel::mpsc::UnboundedSender;
7use futures_util::SinkExt;
8use log::Level;
9
10use crate::error::Error;
11use crate::io::{BufStream, Decode, Encode};
12use crate::net::{MaybeTlsStream, Socket};
13use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
14use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity};
15
// the stream is a separate type from the connection to uphold the invariant where an instantiated
// [PgConnection] is a **valid** connection to postgres

// when a new connection is asked for, we work directly on the [PgStream] type until the
// connection is fully established

// in other words, `self` in any PgConnection method is a live connection to postgres that
// is fully prepared to receive queries

25pub struct PgStream {
26    inner: BufStream<MaybeTlsStream<Socket>>,
27
28    // buffer of unreceived notification messages from `PUBLISH`
29    // this is set when creating a PgListener and only written to if that listener is
30    // re-used for query execution in-between receiving messages
31    pub(crate) notifications: Option<UnboundedSender<Notification>>,
32
33    pub(crate) parameter_statuses: BTreeMap<String, String>,
34
35    pub(crate) server_version_num: Option<u32>,
36}
37
38impl PgStream {
39    pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
40        let socket = match options.fetch_socket() {
41            Some(ref path) => Socket::connect_uds(path).await?,
42            None => Socket::connect_tcp(&options.host, options.port).await?,
43        };
44
45        let inner = BufStream::new(MaybeTlsStream::Raw(socket));
46
47        Ok(Self {
48            inner,
49            notifications: None,
50            parameter_statuses: BTreeMap::default(),
51            server_version_num: None,
52        })
53    }
54
55    pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
56    where
57        T: Encode<'en>,
58    {
59        self.write(message);
60        self.flush().await
61    }
62
63    // Expect a specific type and format
64    pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
65        &mut self,
66        format: MessageFormat,
67    ) -> Result<T, Error> {
68        let message = self.recv().await?;
69
70        if message.format != format {
71            return Err(err_protocol!(
72                "expecting {:?} but received {:?}",
73                format,
74                message.format
75            ));
76        }
77
78        message.decode()
79    }
80
81    pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
82        // all packets in postgres start with a 5-byte header
83        // this header contains the message type and the total length of the message
84        let mut header: Bytes = self.inner.read(5).await?;
85
86        let format = MessageFormat::try_from_u8(header.get_u8())?;
87        let size = (header.get_u32() - 4) as usize;
88
89        let contents = self.inner.read(size).await?;
90
91        Ok(Message { format, contents })
92    }
93
94    // Get the next message from the server
95    // May wait for more data from the server
96    pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
97        loop {
98            let message = self.recv_unchecked().await?;
99
100            match message.format {
101                MessageFormat::ErrorResponse => {
102                    // An error returned from the database server.
103                    return Err(PgDatabaseError(message.decode()?).into());
104                }
105
106                MessageFormat::NotificationResponse => {
107                    if let Some(buffer) = &mut self.notifications {
108                        let notification: Notification = message.decode()?;
109                        let _ = buffer.send(notification).await;
110
111                        continue;
112                    }
113                }
114
115                MessageFormat::ParameterStatus => {
116                    // informs the frontend about the current (initial)
117                    // setting of backend parameters
118
119                    let ParameterStatus { name, value } = message.decode()?;
120                    // TODO: handle `client_encoding`, `DateStyle` change
121
122                    match name.as_str() {
123                        "server_version" => {
124                            self.server_version_num = parse_server_version(&value);
125                        }
126                        _ => {
127                            self.parameter_statuses.insert(name, value);
128                        }
129                    }
130
131                    continue;
132                }
133
134                MessageFormat::NoticeResponse => {
135                    // do we need this to be more configurable?
136                    // if you are reading this comment and think so, open an issue
137
138                    let notice: Notice = message.decode()?;
139
140                    let lvl = match notice.severity() {
141                        PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error,
142                        PgSeverity::Warning => Level::Warn,
143                        PgSeverity::Notice => Level::Info,
144                        PgSeverity::Debug => Level::Debug,
145                        PgSeverity::Info => Level::Trace,
146                        PgSeverity::Log => Level::Trace,
147                    };
148
149                    if log::log_enabled!(target: "sqlx::postgres::notice", lvl) {
150                        log::logger().log(
151                            &log::Record::builder()
152                                .args(format_args!("{}", notice.message()))
153                                .level(lvl)
154                                .module_path_static(Some("sqlx::postgres::notice"))
155                                .target("sqlx::postgres::notice")
156                                .file_static(Some(file!()))
157                                .line(Some(line!()))
158                                .build(),
159                        );
160                    }
161
162                    continue;
163                }
164
165                _ => {}
166            }
167
168            return Ok(message);
169        }
170    }
171}
172
173impl Deref for PgStream {
174    type Target = BufStream<MaybeTlsStream<Socket>>;
175
176    #[inline]
177    fn deref(&self) -> &Self::Target {
178        &self.inner
179    }
180}
181
182impl DerefMut for PgStream {
183    #[inline]
184    fn deref_mut(&mut self) -> &mut Self::Target {
185        &mut self.inner
186    }
187}
188
// reference:
// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065

/// Converts a `server_version` string into libpq's numeric version form.
///
/// Examples: `"9.6.1"` -> `90601`, `"10.1"` -> `100001`, `"10devel"` -> `100000`.
///
/// Like libpq's `sscanf(value, "%d.%d.%d", ...)`, this reads at most three dot-separated
/// leading integers and ignores anything after them (e.g. `"14.2 (Debian ...)"`).
///
/// Returns `None` in two cases:
/// * the string does not start with a number;
/// * the computed version does not fit in a `u32`.
fn parse_server_version(s: &str) -> Option<u32> {
    let mut parts = Vec::<u32>::with_capacity(3);
    let mut rest = s;

    // Mirror `sscanf("%d.%d.%d")`. Stop at any of:
    // * the third component;
    // * the first component that is not a number;
    // * a number that is not followed by a '.'.
    while parts.len() < 3 {
        let digits_end = rest
            .find(|ch: char| !ch.is_ascii_digit())
            .unwrap_or(rest.len());

        match u32::from_str(&rest[..digits_end]) {
            Ok(num) => parts.push(num),
            Err(_) => break,
        }

        match rest[digits_end..].strip_prefix('.') {
            Some(tail) => rest = tail,
            None => break,
        }
    }

    // The string comes from the server, so out-of-range component values must produce
    // `None` rather than overflowing.
    let version_num = match parts[..] {
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev)?,
        [major, minor] if major >= 10 => major.checked_mul(100 * 100)?.checked_add(minor)?,
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100)?,
        [major] => major.checked_mul(100 * 100)?,
        _ => return None,
    };

    Some(version_num)
}
233
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        // old style, e.g. 9.6.1
        assert_eq!(parse_server_version("9.6.1"), Some(90601));
        // new style, e.g. 10.1
        assert_eq!(parse_server_version("10.1"), Some(100001));
        // old style without minor version, e.g. 9.6devel
        assert_eq!(parse_server_version("9.6devel"), Some(90600));
        // new style without minor version, e.g. 10devel
        assert_eq!(parse_server_version("10devel"), Some(100000));
        assert_eq!(parse_server_version("13devel87"), Some(130000));
        // build information after the version number is ignored
        assert_eq!(
            parse_server_version("14.2 (Debian 14.2-1.pgdg110+1)"),
            Some(140002)
        );
        // unknown
        assert_eq!(parse_server_version("unknown"), None);
        assert_eq!(parse_server_version(""), None);
    }
}