xitca_postgres/
session.rs

//! Session handling once a server connection is established: authentication, credential exchange and initial session setup.
2
3use core::net::SocketAddr;
4
5use fallible_iterator::FallibleIterator;
6use postgres_protocol::{
7    authentication::{self, sasl},
8    message::{backend, frontend},
9};
10use xitca_io::{bytes::BytesMut, io::AsyncIo};
11
12use super::{
13    config::{Config, SslMode, SslNegotiation},
14    driver::generic::GenericDriver,
15    error::{AuthenticationError, Error},
16};
17
/// The kind of session a connection is required to land on.
///
/// Any value other than [`TargetSessionAttrs::Any`] asks connection setup to verify the
/// server's read-only state before the session is accepted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Accept whatever session the server gives; nothing is checked.
    Any,
    /// Only accept a session that permits writes.
    ReadWrite,
    /// Only accept a session that is read-only.
    ReadOnly,
}
29
30/// information about session. used for canceling query
31#[derive(Clone)]
32pub struct Session {
33    pub(crate) id: i32,
34    pub(crate) key: i32,
35    pub(crate) info: ConnectInfo,
36}
37
38#[derive(Clone, Default)]
39pub(crate) struct ConnectInfo {
40    pub(crate) addr: Addr,
41    pub(crate) ssl_mode: SslMode,
42    pub(crate) ssl_negotiation: SslNegotiation,
43}
44
45impl ConnectInfo {
46    pub(crate) fn new(addr: Addr, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Self {
47        Self {
48            addr,
49            ssl_mode,
50            ssl_negotiation,
51        }
52    }
53}
54
/// Transport-level address a connection was made to.
///
/// Each connectable variant carries a `Box<str>` next to the resolved target
/// (presumably the host name as configured — confirm against the connect code).
#[derive(Default, Clone)]
pub(crate) enum Addr {
    /// TCP socket address.
    Tcp(Box<str>, SocketAddr),
    /// Unix domain socket path.
    #[cfg(unix)]
    Unix(Box<str>, std::path::PathBuf),
    /// QUIC endpoint.
    #[cfg(feature = "quic")]
    Quic(Box<str>, SocketAddr),
    /// The IO was supplied by the user, so this crate has no address it could connect to.
    #[default]
    None,
}
66
67impl Session {
68    fn new(info: ConnectInfo) -> Self {
69        Self { id: 0, key: 0, info }
70    }
71}
72
73impl Session {
74    #[allow(clippy::needless_pass_by_ref_mut)] // dumb clippy
75    #[cold]
76    #[inline(never)]
77    pub(super) async fn prepare_session<Io>(
78        info: ConnectInfo,
79        drv: &mut GenericDriver<Io>,
80        cfg: &Config,
81    ) -> Result<Self, Error>
82    where
83        Io: AsyncIo + Send,
84    {
85        let mut buf = BytesMut::new();
86
87        auth(drv, cfg, &mut buf).await?;
88
89        let mut session = Session::new(info);
90
91        loop {
92            match drv.recv().await? {
93                backend::Message::ReadyForQuery(_) => break,
94                backend::Message::BackendKeyData(body) => {
95                    session.id = body.process_id();
96                    session.key = body.secret_key();
97                }
98                backend::Message::ParameterStatus(body) => {
99                    // TODO: handling params?
100                    let _name = body.name()?;
101                    let _value = body.value()?;
102                }
103                backend::Message::ErrorResponse(body) => return Err(Error::db(body.fields())),
104                backend::Message::NoticeResponse(_) => {
105                    // TODO: collect notice and let Driver emit it when polled?
106                }
107                _ => return Err(Error::unexpected()),
108            }
109        }
110
111        if !matches!(cfg.get_target_session_attrs(), TargetSessionAttrs::Any) {
112            frontend::query("SHOW transaction_read_only", &mut buf)?;
113            let msg = buf.split();
114            drv.send(msg).await?;
115            // TODO: use RowSimple for parsing?
116            loop {
117                match drv.recv().await? {
118                    backend::Message::DataRow(body) => {
119                        let range = body.ranges().next()?.flatten().ok_or(Error::todo())?;
120                        let slice = &body.buffer()[range.start..range.end];
121                        match (slice, cfg.get_target_session_attrs()) {
122                            (b"on", TargetSessionAttrs::ReadWrite) => return Err(Error::todo()),
123                            (b"off", TargetSessionAttrs::ReadOnly) => return Err(Error::todo()),
124                            _ => {}
125                        }
126                    }
127                    backend::Message::RowDescription(_) | backend::Message::CommandComplete(_) => {}
128                    backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => break,
129                    _ => return Err(Error::unexpected()),
130                }
131            }
132        }
133
134        Ok(session)
135    }
136}
137
138#[cold]
139#[inline(never)]
140async fn auth<Io>(drv: &mut GenericDriver<Io>, cfg: &Config, buf: &mut BytesMut) -> Result<(), Error>
141where
142    Io: AsyncIo + Send,
143{
144    let mut params = vec![("client_encoding", "UTF8")];
145    if let Some(user) = &cfg.user {
146        params.push(("user", &**user));
147    }
148    if let Some(dbname) = &cfg.dbname {
149        params.push(("database", &**dbname));
150    }
151    if let Some(options) = &cfg.options {
152        params.push(("options", &**options));
153    }
154    if let Some(application_name) = &cfg.application_name {
155        params.push(("application_name", &**application_name));
156    }
157
158    frontend::startup_message(params, buf)?;
159    let msg = buf.split();
160    drv.send(msg).await?;
161
162    loop {
163        match drv.recv().await? {
164            backend::Message::AuthenticationOk => return Ok(()),
165            backend::Message::AuthenticationCleartextPassword => {
166                let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
167                send_pass(drv, pass, buf).await?;
168            }
169            backend::Message::AuthenticationMd5Password(body) => {
170                let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
171                let user = cfg.get_user().ok_or(AuthenticationError::MissingUserName)?.as_bytes();
172                let pass = authentication::md5_hash(user, pass, body.salt());
173                send_pass(drv, pass, buf).await?;
174            }
175            backend::Message::AuthenticationSasl(body) => {
176                let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
177
178                let mut is_scram = false;
179                let mut is_scram_plus = false;
180                let mut mechanisms = body.mechanisms();
181
182                while let Some(mechanism) = mechanisms.next()? {
183                    match mechanism {
184                        sasl::SCRAM_SHA_256 => is_scram = true,
185                        sasl::SCRAM_SHA_256_PLUS => is_scram_plus = true,
186                        _ => {}
187                    }
188                }
189
190                let (channel_binding, mechanism) = match (is_scram_plus, is_scram) {
191                    (true, is_scram) => {
192                        let buf = cfg.get_tls_server_end_point();
193                        match (buf, is_scram) {
194                            (Some(buf), _) => (
195                                sasl::ChannelBinding::tls_server_end_point(buf.to_owned()),
196                                sasl::SCRAM_SHA_256_PLUS,
197                            ),
198                            (None, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
199                            // server ask for channel binding but no tls_server_end_point can be
200                            // found.
201                            _ => return Err(Error::todo()),
202                        }
203                    }
204                    (false, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
205                    // TODO: return "unsupported SASL mechanism" error.
206                    (false, false) => return Err(Error::todo()),
207                };
208
209                let mut scram = sasl::ScramSha256::new(pass, channel_binding);
210
211                frontend::sasl_initial_response(mechanism, scram.message(), buf)?;
212                let msg = buf.split();
213                drv.send(msg).await?;
214
215                match drv.recv().await? {
216                    backend::Message::AuthenticationSaslContinue(body) => {
217                        scram.update(body.data())?;
218                        frontend::sasl_response(scram.message(), buf)?;
219                        let msg = buf.split();
220                        drv.send(msg).await?;
221                    }
222                    _ => return Err(Error::todo()),
223                }
224
225                match drv.recv().await? {
226                    backend::Message::AuthenticationSaslFinal(body) => scram.finish(body.data())?,
227                    _ => return Err(Error::todo()),
228                }
229            }
230            backend::Message::ErrorResponse(_) => return Err(Error::from(AuthenticationError::WrongPassWord)),
231            _ => {}
232        }
233    }
234}
235
236async fn send_pass<Io>(drv: &mut GenericDriver<Io>, pass: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), Error>
237where
238    Io: AsyncIo + Send,
239{
240    frontend::password_message(pass.as_ref(), buf)?;
241    let msg = buf.split();
242    drv.send(msg).await
243}