Skip to main content

zero_postgres/protocol/backend/
auth.rs

1//! Authentication-related backend messages.
2
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::error::{Error, Result};
6use crate::protocol::codec::{read_cstr, read_i32, read_u32};
7use crate::protocol::types::TransactionStatus;
8
/// Authentication request codes.
///
/// `AuthenticationMessage::parse` reads an `i32` from the start of the payload and
/// matches it against these constants to pick the message variant.
pub mod auth_type {
    /// Authentication completed successfully.
    pub const OK: i32 = 0;
    /// Kerberos V5 authentication is requested.
    pub const KERBEROS_V5: i32 = 2;
    /// A cleartext password is requested.
    pub const CLEARTEXT_PASSWORD: i32 = 3;
    /// An MD5-hashed password is requested; a 4-byte salt follows the code.
    pub const MD5_PASSWORD: i32 = 5;
    /// GSSAPI authentication is requested.
    pub const GSS: i32 = 7;
    /// A GSSAPI exchange continues; additional data follows the code.
    pub const GSS_CONTINUE: i32 = 8;
    /// SSPI authentication is requested.
    pub const SSPI: i32 = 9;
    /// SASL authentication is requested; a list of mechanism names follows.
    pub const SASL: i32 = 10;
    /// A SASL exchange continues (server-first-message follows).
    pub const SASL_CONTINUE: i32 = 11;
    /// A SASL exchange finishes (server-final-message follows).
    pub const SASL_FINAL: i32 = 12;
}
22
/// A decoded Authentication message sent by the server.
///
/// Variants that carry bytes or strings borrow them from the payload slice that was
/// handed to `AuthenticationMessage::parse`, so no copies are made.
#[derive(Debug)]
pub enum AuthenticationMessage<'a> {
    /// Authentication completed successfully.
    Ok,
    /// The server requests Kerberos V5 authentication.
    KerberosV5,
    /// The server requests a cleartext password.
    CleartextPassword,
    /// The server requests an MD5 password; `salt` holds the 4 salt bytes it sent.
    Md5Password { salt: [u8; 4] },
    /// The server requests GSSAPI authentication.
    Gss,
    /// A GSSAPI exchange continues; `data` is everything after the request code.
    GssContinue { data: &'a [u8] },
    /// The server requests SSPI authentication.
    Sspi,
    /// The server requests SASL authentication; `mechanisms` lists the mechanism names
    /// in the order they appeared in the message.
    Sasl { mechanisms: Vec<&'a str> },
    /// A SASL exchange continues; `data` is the server-first-message.
    SaslContinue { data: &'a [u8] },
    /// A SASL exchange finishes; `data` is the server-final-message.
    SaslFinal { data: &'a [u8] },
}
47
48impl<'a> AuthenticationMessage<'a> {
49    /// Parse an Authentication message from payload bytes.
50    pub fn parse(payload: &'a [u8]) -> Result<Self> {
51        let (auth_type, rest) = read_i32(payload)?;
52
53        match auth_type {
54            auth_type::OK => Ok(AuthenticationMessage::Ok),
55            auth_type::KERBEROS_V5 => Ok(AuthenticationMessage::KerberosV5),
56            auth_type::CLEARTEXT_PASSWORD => Ok(AuthenticationMessage::CleartextPassword),
57            auth_type::MD5_PASSWORD => {
58                if rest.len() < 4 {
59                    return Err(Error::LibraryBug("MD5Password: missing salt".into()));
60                }
61                let mut salt = [0u8; 4];
62                salt.copy_from_slice(&rest[..4]);
63                Ok(AuthenticationMessage::Md5Password { salt })
64            }
65            auth_type::GSS => Ok(AuthenticationMessage::Gss),
66            auth_type::GSS_CONTINUE => Ok(AuthenticationMessage::GssContinue { data: rest }),
67            auth_type::SSPI => Ok(AuthenticationMessage::Sspi),
68            auth_type::SASL => {
69                let mut mechanisms = Vec::new();
70                let mut data = rest;
71                while !data.is_empty() && data[0] != 0 {
72                    let (mechanism, remaining) = read_cstr(data)?;
73                    mechanisms.push(mechanism);
74                    data = remaining;
75                }
76                Ok(AuthenticationMessage::Sasl { mechanisms })
77            }
78            auth_type::SASL_CONTINUE => Ok(AuthenticationMessage::SaslContinue { data: rest }),
79            auth_type::SASL_FINAL => Ok(AuthenticationMessage::SaslFinal { data: rest }),
80            _ => Err(Error::LibraryBug(format!(
81                "Unknown authentication type: {}",
82                auth_type
83            ))),
84        }
85    }
86}
87
/// BackendKeyData message: the backend's process ID plus the secret key used for
/// cancellation requests.
///
/// In protocol 3.2 the secret key has a variable length (4-256 bytes), so it is
/// stored as an owned byte buffer rather than a fixed-size integer.
#[derive(Clone, Debug)]
pub struct BackendKeyData {
    /// Process ID of the backend.
    pid: u32,
    /// Cancellation secret key (variable length in protocol 3.2).
    secret_key: Vec<u8>,
}
98
99impl BackendKeyData {
100    /// Parse a BackendKeyData message from payload bytes.
101    pub fn parse(payload: &[u8]) -> Result<Self> {
102        if payload.len() < 4 {
103            return Err(Error::LibraryBug(
104                "BackendKeyData: payload too short".into(),
105            ));
106        }
107        let (pid, rest) = read_u32(payload)?;
108        if rest.len() < 4 || rest.len() > 256 {
109            return Err(Error::LibraryBug(format!(
110                "BackendKeyData: invalid secret key length {}",
111                rest.len()
112            )));
113        }
114        Ok(Self {
115            pid,
116            secret_key: rest.to_vec(),
117        })
118    }
119
120    /// Get the process ID.
121    pub fn process_id(&self) -> u32 {
122        self.pid
123    }
124
125    /// Get the secret key bytes.
126    pub fn secret_key(&self) -> &[u8] {
127        &self.secret_key
128    }
129}
130
/// ParameterStatus message: one server run-time parameter and its current value.
///
/// Both strings borrow from the payload the message was parsed from.
#[derive(Clone, Debug)]
pub struct ParameterStatus<'a> {
    /// Name of the parameter.
    pub name: &'a str,
    /// Current value of the parameter.
    pub value: &'a str,
}
139
140impl<'a> ParameterStatus<'a> {
141    /// Parse a ParameterStatus message from payload bytes.
142    pub fn parse(payload: &'a [u8]) -> Result<Self> {
143        let (name, rest) = read_cstr(payload)?;
144        let (value, _) = read_cstr(rest)?;
145        Ok(Self { name, value })
146    }
147}
148
149/// ReadyForQuery message - indicates server is ready for a new query.
150#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
151#[repr(C, packed)]
152pub struct ReadyForQuery {
153    /// Transaction status byte
154    pub status: u8,
155}
156
157impl ReadyForQuery {
158    /// Parse a ReadyForQuery message from payload bytes.
159    pub fn parse(payload: &[u8]) -> Result<&Self> {
160        Self::ref_from_bytes(payload)
161            .map_err(|e| Error::LibraryBug(format!("ReadyForQuery: {e:?}")))
162    }
163
164    /// Get the transaction status.
165    pub fn transaction_status(&self) -> Option<TransactionStatus> {
166        TransactionStatus::from_byte(self.status)
167    }
168}
169
/// NotificationResponse message: an asynchronous LISTEN/NOTIFY notification.
///
/// The channel name and payload borrow from the message bytes.
#[derive(Clone, Debug)]
pub struct NotificationResponse<'a> {
    /// Process ID of the backend that sent the notification.
    pub pid: u32,
    /// Name of the channel the notification was raised on.
    pub channel: &'a str,
    /// Payload string attached to the notification.
    pub payload: &'a str,
}
180
181impl<'a> NotificationResponse<'a> {
182    /// Parse a NotificationResponse message from payload bytes.
183    pub fn parse(payload: &'a [u8]) -> Result<Self> {
184        let (pid, rest) = read_u32(payload)?;
185        let (channel, rest) = read_cstr(rest)?;
186        let (payload_str, _) = read_cstr(rest)?;
187        Ok(Self {
188            pid,
189            channel,
190            payload: payload_str,
191        })
192    }
193}
194
/// NegotiateProtocolVersion message: the server does not support some of the protocol
/// features the client requested.
///
/// Option names borrow from the message bytes.
#[derive(Clone, Debug)]
pub struct NegotiateProtocolVersion<'a> {
    /// Newest minor protocol version the server supports.
    pub newest_minor_version: u32,
    /// Protocol options the server did not recognise, in message order.
    pub unrecognized_options: Vec<&'a str>,
}
203
204impl<'a> NegotiateProtocolVersion<'a> {
205    /// Parse a NegotiateProtocolVersion message from payload bytes.
206    pub fn parse(payload: &'a [u8]) -> Result<Self> {
207        let (newest_minor_version, rest) = read_u32(payload)?;
208        let (num_options, mut rest) = read_u32(rest)?;
209
210        let mut unrecognized_options = Vec::with_capacity(num_options as usize);
211        for _ in 0..num_options {
212            let (option, remaining) = read_cstr(rest)?;
213            unrecognized_options.push(option);
214            rest = remaining;
215        }
216
217        Ok(Self {
218            newest_minor_version,
219            unrecognized_options,
220        })
221    }
222}