sqlx_core_oldapi/postgres/connection/
stream.rs1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut};
3use std::str::FromStr;
4
5use bytes::{Buf, Bytes};
6use futures_channel::mpsc::UnboundedSender;
7use futures_util::SinkExt;
8use log::Level;
9
10use crate::error::Error;
11use crate::io::{BufStream, Decode, Encode};
12use crate::net::{MaybeTlsStream, Socket};
13use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
14use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity};
15
/// A buffered connection to a PostgreSQL server, together with the
/// connection-level state that `PgStream::recv` updates as asynchronous
/// backend messages arrive.
pub struct PgStream {
    /// Buffered byte stream over the socket. `connect` creates it as
    /// `MaybeTlsStream::Raw`, i.e. without TLS.
    inner: BufStream<MaybeTlsStream<Socket>>,

    /// When `Some`, `recv` decodes `NotificationResponse` messages and forwards
    /// them to this channel instead of returning them to the caller.
    pub(crate) notifications: Option<UnboundedSender<Notification>>,

    /// Most recent value of each `ParameterStatus` reported by the server, keyed
    /// by parameter name. `server_version` is not stored here; it is parsed into
    /// `server_version_num` instead.
    pub(crate) parameter_statuses: BTreeMap<String, String>,

    /// Numeric form of the reported `server_version` (e.g. `90601` for `9.6.1`).
    /// `None` until that parameter is received, or if it could not be parsed.
    pub(crate) server_version_num: Option<u32>,
}
37
38impl PgStream {
39 pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
40 let socket = match options.fetch_socket() {
41 Some(ref path) => Socket::connect_uds(path).await?,
42 None => Socket::connect_tcp(&options.host, options.port).await?,
43 };
44
45 let inner = BufStream::new(MaybeTlsStream::Raw(socket));
46
47 Ok(Self {
48 inner,
49 notifications: None,
50 parameter_statuses: BTreeMap::default(),
51 server_version_num: None,
52 })
53 }
54
55 pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
56 where
57 T: Encode<'en>,
58 {
59 self.write(message);
60 self.flush().await
61 }
62
63 pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
65 &mut self,
66 format: MessageFormat,
67 ) -> Result<T, Error> {
68 let message = self.recv().await?;
69
70 if message.format != format {
71 return Err(err_protocol!(
72 "expecting {:?} but received {:?}",
73 format,
74 message.format
75 ));
76 }
77
78 message.decode()
79 }
80
81 pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
82 let mut header: Bytes = self.inner.read(5).await?;
85
86 let format = MessageFormat::try_from_u8(header.get_u8())?;
87 let size = (header.get_u32() - 4) as usize;
88
89 let contents = self.inner.read(size).await?;
90
91 Ok(Message { format, contents })
92 }
93
94 pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
97 loop {
98 let message = self.recv_unchecked().await?;
99
100 match message.format {
101 MessageFormat::ErrorResponse => {
102 return Err(PgDatabaseError(message.decode()?).into());
104 }
105
106 MessageFormat::NotificationResponse => {
107 if let Some(buffer) = &mut self.notifications {
108 let notification: Notification = message.decode()?;
109 let _ = buffer.send(notification).await;
110
111 continue;
112 }
113 }
114
115 MessageFormat::ParameterStatus => {
116 let ParameterStatus { name, value } = message.decode()?;
120 match name.as_str() {
123 "server_version" => {
124 self.server_version_num = parse_server_version(&value);
125 }
126 _ => {
127 self.parameter_statuses.insert(name, value);
128 }
129 }
130
131 continue;
132 }
133
134 MessageFormat::NoticeResponse => {
135 let notice: Notice = message.decode()?;
139
140 let lvl = match notice.severity() {
141 PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => Level::Error,
142 PgSeverity::Warning => Level::Warn,
143 PgSeverity::Notice => Level::Info,
144 PgSeverity::Debug => Level::Debug,
145 PgSeverity::Info => Level::Trace,
146 PgSeverity::Log => Level::Trace,
147 };
148
149 if log::log_enabled!(target: "sqlx::postgres::notice", lvl) {
150 log::logger().log(
151 &log::Record::builder()
152 .args(format_args!("{}", notice.message()))
153 .level(lvl)
154 .module_path_static(Some("sqlx::postgres::notice"))
155 .target("sqlx::postgres::notice")
156 .file_static(Some(file!()))
157 .line(Some(line!()))
158 .build(),
159 );
160 }
161
162 continue;
163 }
164
165 _ => {}
166 }
167
168 return Ok(message);
169 }
170 }
171}
172
173impl Deref for PgStream {
174 type Target = BufStream<MaybeTlsStream<Socket>>;
175
176 #[inline]
177 fn deref(&self) -> &Self::Target {
178 &self.inner
179 }
180}
181
182impl DerefMut for PgStream {
183 #[inline]
184 fn deref_mut(&mut self) -> &mut Self::Target {
185 &mut self.inner
186 }
187}
188
/// Converts a PostgreSQL `server_version` string into its numeric
/// `server_version_num` form, e.g. `"9.6.1"` -> `90601` and `"10.1"` -> `100001`.
///
/// Dot-separated decimal components are read from the start of `s`. Parsing
/// stops at the first character that is neither a digit nor `.`, so suffixes
/// such as `devel`, `beta1` or ` (Debian ...)` are ignored. Missing minor or
/// patch components count as zero. A two-component version uses the 10+ scheme
/// (`major.minor`) when `major >= 10`, and the older `major.minor` (patch 0)
/// scheme otherwise.
///
/// Returns `None` if no leading number is found, if more than three components
/// are present, or if the result does not fit in a `u32`. The input comes from
/// the server, so overflow is reported rather than allowed to panic or wrap.
fn parse_server_version(s: &str) -> Option<u32> {
    let mut parts = Vec::<u32>::with_capacity(3);

    let mut from = 0;
    let mut chs = s.char_indices().peekable();
    while let Some((i, ch)) = chs.next() {
        match ch {
            '.' => {
                if let Ok(num) = u32::from_str(&s[from..i]) {
                    parts.push(num);
                    // '.' is a single byte, so `i + 1` is a char boundary.
                    from = i + 1;
                } else {
                    break;
                }
            }
            _ if ch.is_digit(10) => {
                // A digit at the very end closes the final component.
                if chs.peek().is_none() {
                    if let Ok(num) = u32::from_str(&s[from..]) {
                        parts.push(num);
                    }
                    break;
                }
            }
            _ => {
                // Any other character ends the version; keep the pending component.
                if let Ok(num) = u32::from_str(&s[from..i]) {
                    parts.push(num);
                }
                break;
            }
        };
    }

    let version_num = match *parts.as_slice() {
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev)?,
        [major, minor] if major >= 10 => major.checked_mul(100 * 100)?.checked_add(minor)?,
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100)?,
        [major] => major.checked_mul(100 * 100)?,
        _ => return None,
    };

    Some(version_num)
}
233
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        let cases: &[(&str, Option<u32>)] = &[
            // pre-10 scheme: major.minor.patch
            ("9.6.1", Some(90601)),
            // 10+ scheme: major.minor
            ("10.1", Some(100001)),
            // development and pre-release builds
            ("9.6devel", Some(90600)),
            ("10devel", Some(100000)),
            ("13devel87", Some(130000)),
            ("14beta1", Some(140000)),
            // distribution packages append build details after a space
            ("12.3 (Debian 12.3-1.pgdg110+1)", Some(120003)),
            // nothing numeric to parse
            ("unknown", None),
            ("", None),
        ];

        for &(input, expected) in cases {
            assert_eq!(parse_server_version(input), expected, "input: {:?}", input);
        }
    }
}