Skip to main content

sqlmodel_postgres/protocol/
reader.rs

1//! PostgreSQL message decoder.
2//!
3//! This module handles decoding backend messages from the wire protocol format.
4
5#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8    BackendMessage, ErrorFields, FieldDescription, TransactionStatus, auth_type, backend_type,
9};
10use std::error::Error as StdError;
11use std::fmt;
12
/// Errors that can occur while decoding PostgreSQL protocol messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// Not enough bytes to parse a full message.
    Incomplete,
    /// Invalid length prefix encountered.
    InvalidLength {
        /// The length value exactly as it was read from the frame header.
        length: i32,
    },
    /// Message exceeds configured maximum size.
    MessageTooLarge {
        /// Size of the offending frame in bytes.
        length: usize,
        /// The limit the frame was checked against.
        max: usize,
    },
    /// Unknown message type byte.
    UnknownMessageType(u8),
    /// UTF-8 decoding error while parsing strings.
    Utf8(std::string::FromUtf8Error),
    /// Unexpected end of buffer while parsing a field.
    UnexpectedEof,
    /// Invalid field encoding or value.
    InvalidField(&'static str),
}

impl fmt::Display for ProtocolError {
    /// Write a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProtocolError::Incomplete => write!(f, "incomplete message"),
            ProtocolError::InvalidLength { length } => {
                write!(f, "invalid message length: {}", length)
            }
            ProtocolError::MessageTooLarge { length, max } => {
                write!(f, "message too large: {} > {}", length, max)
            }
            ProtocolError::UnknownMessageType(ty) => {
                write!(f, "unknown message type: 0x{:02x}", ty)
            }
            ProtocolError::Utf8(err) => write!(f, "utf-8 error: {}", err),
            ProtocolError::UnexpectedEof => write!(f, "unexpected end of buffer"),
            ProtocolError::InvalidField(msg) => write!(f, "invalid field: {}", msg),
        }
    }
}

impl StdError for ProtocolError {
    /// Expose the wrapped UTF-8 error so callers walking the error chain can
    /// reach it; every other variant is a leaf and has no underlying cause.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            ProtocolError::Utf8(err) => Some(err),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision about whether it carries a source.
            ProtocolError::Incomplete
            | ProtocolError::InvalidLength { .. }
            | ProtocolError::MessageTooLarge { .. }
            | ProtocolError::UnknownMessageType(_)
            | ProtocolError::UnexpectedEof
            | ProtocolError::InvalidField(_) => None,
        }
    }
}

impl From<std::string::FromUtf8Error> for ProtocolError {
    /// Wrap a UTF-8 decoding failure, which lets `?` be used on
    /// `String::from_utf8` results inside the decoder.
    fn from(err: std::string::FromUtf8Error) -> Self {
        ProtocolError::Utf8(err)
    }
}
59
/// Incremental reader for PostgreSQL backend messages.
///
/// Raw bytes are accumulated in an internal buffer and complete frames are
/// decoded from its front as they become available.
#[derive(Debug, Clone)]
pub struct MessageReader {
    /// Bytes received so far that have not yet been consumed as a frame.
    buf: Vec<u8>,
    /// Largest total frame size, in bytes, that the reader will accept.
    max_message_size: usize,
}
66
67impl Default for MessageReader {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl MessageReader {
74    /// Create a new reader with a default max message size.
75    pub fn new() -> Self {
76        Self::with_max_size(8 * 1024 * 1024)
77    }
78
79    /// Create a new reader with a custom max message size.
80    pub fn with_max_size(max_message_size: usize) -> Self {
81        Self {
82            buf: Vec::new(),
83            max_message_size,
84        }
85    }
86
87    /// Number of bytes currently buffered.
88    pub fn buffered_len(&self) -> usize {
89        self.buf.len()
90    }
91
92    /// Append raw bytes to the internal buffer without parsing.
93    ///
94    /// Use this when the caller will drive parsing via [`next_message()`] in
95    /// its own loop (e.g. `receive_message_no_cx`). This avoids the
96    /// consume-then-discard bug where [`feed()`] parses and returns messages
97    /// that the caller never inspects.
98    pub fn push(&mut self, data: &[u8]) {
99        self.buf.extend_from_slice(data);
100    }
101
102    /// Feed bytes into the reader and return any complete messages.
103    pub fn feed(&mut self, data: &[u8]) -> Result<Vec<BackendMessage>, ProtocolError> {
104        self.buf.extend_from_slice(data);
105
106        let mut messages = Vec::new();
107        while let Some(msg) = self.next_message()? {
108            messages.push(msg);
109        }
110        Ok(messages)
111    }
112
113    /// Attempt to parse the next message from the internal buffer.
114    pub fn next_message(&mut self) -> Result<Option<BackendMessage>, ProtocolError> {
115        if self.buf.len() < 5 {
116            return Ok(None);
117        }
118
119        let length = i32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]]);
120        if length < 4 {
121            return Err(ProtocolError::InvalidLength { length });
122        }
123
124        let total_len = length as usize + 1;
125        if total_len > self.max_message_size {
126            return Err(ProtocolError::MessageTooLarge {
127                length: total_len,
128                max: self.max_message_size,
129            });
130        }
131
132        if self.buf.len() < total_len {
133            return Ok(None);
134        }
135
136        let frame = self.buf[..total_len].to_vec();
137        self.buf.drain(..total_len);
138        Ok(Some(Self::parse_message(&frame)?))
139    }
140
141    /// Parse a single full message frame (type + length + payload).
142    pub fn parse_message(frame: &[u8]) -> Result<BackendMessage, ProtocolError> {
143        if frame.len() < 5 {
144            return Err(ProtocolError::Incomplete);
145        }
146
147        let ty = frame[0];
148        let length = i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]);
149        if length < 4 {
150            return Err(ProtocolError::InvalidLength { length });
151        }
152
153        let total_len = length as usize + 1;
154        if frame.len() < total_len {
155            return Err(ProtocolError::Incomplete);
156        }
157
158        let payload = &frame[5..total_len];
159        let mut cur = Cursor::new(payload);
160
161        match ty {
162            backend_type::AUTHENTICATION => parse_authentication(&mut cur),
163            backend_type::BACKEND_KEY_DATA => parse_backend_key_data(&mut cur),
164            backend_type::PARAMETER_STATUS => parse_parameter_status(&mut cur),
165            backend_type::READY_FOR_QUERY => parse_ready_for_query(&mut cur),
166            backend_type::ROW_DESCRIPTION => parse_row_description(&mut cur),
167            backend_type::DATA_ROW => parse_data_row(&mut cur),
168            backend_type::COMMAND_COMPLETE => parse_command_complete(&mut cur),
169            backend_type::EMPTY_QUERY => Ok(BackendMessage::EmptyQueryResponse),
170            backend_type::PARSE_COMPLETE => Ok(BackendMessage::ParseComplete),
171            backend_type::BIND_COMPLETE => Ok(BackendMessage::BindComplete),
172            backend_type::CLOSE_COMPLETE => Ok(BackendMessage::CloseComplete),
173            backend_type::PARAMETER_DESCRIPTION => parse_parameter_description(&mut cur),
174            backend_type::NO_DATA => Ok(BackendMessage::NoData),
175            backend_type::PORTAL_SUSPENDED => Ok(BackendMessage::PortalSuspended),
176            backend_type::ERROR_RESPONSE => parse_error_response(&mut cur, true),
177            backend_type::NOTICE_RESPONSE => parse_error_response(&mut cur, false),
178            backend_type::COPY_IN_RESPONSE => parse_copy_in_response(&mut cur),
179            backend_type::COPY_OUT_RESPONSE => parse_copy_out_response(&mut cur),
180            backend_type::COPY_BOTH_RESPONSE => parse_copy_both_response(&mut cur),
181            backend_type::COPY_DATA => Ok(BackendMessage::CopyData(cur.take_remaining())),
182            backend_type::COPY_DONE => Ok(BackendMessage::CopyDone),
183            backend_type::NOTIFICATION_RESPONSE => parse_notification_response(&mut cur),
184            backend_type::FUNCTION_CALL_RESPONSE => parse_function_call_response(&mut cur),
185            backend_type::NEGOTIATE_PROTOCOL_VERSION => parse_negotiate_protocol_version(&mut cur),
186            _ => Err(ProtocolError::UnknownMessageType(ty)),
187        }
188    }
189}
190
191fn parse_authentication(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
192    let auth_type = cur.read_i32()?;
193    match auth_type {
194        auth_type::OK => Ok(BackendMessage::AuthenticationOk),
195        auth_type::CLEARTEXT_PASSWORD => Ok(BackendMessage::AuthenticationCleartextPassword),
196        auth_type::MD5_PASSWORD => {
197            let salt = cur.read_bytes(4)?;
198            let mut buf = [0_u8; 4];
199            buf.copy_from_slice(salt);
200            Ok(BackendMessage::AuthenticationMD5Password(buf))
201        }
202        auth_type::SASL => {
203            let mut mechanisms = Vec::new();
204            loop {
205                let mech = cur.read_cstring()?;
206                if mech.is_empty() {
207                    break;
208                }
209                mechanisms.push(mech);
210            }
211            Ok(BackendMessage::AuthenticationSASL(mechanisms))
212        }
213        auth_type::SASL_CONTINUE => Ok(BackendMessage::AuthenticationSASLContinue(
214            cur.take_remaining(),
215        )),
216        auth_type::SASL_FINAL => Ok(BackendMessage::AuthenticationSASLFinal(
217            cur.take_remaining(),
218        )),
219        _ => Err(ProtocolError::InvalidField("unknown auth type")),
220    }
221}
222
223fn parse_backend_key_data(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
224    let process_id = cur.read_i32()?;
225    let secret_key = cur.read_i32()?;
226    Ok(BackendMessage::BackendKeyData {
227        process_id,
228        secret_key,
229    })
230}
231
232fn parse_parameter_status(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
233    let name = cur.read_cstring()?;
234    let value = cur.read_cstring()?;
235    Ok(BackendMessage::ParameterStatus { name, value })
236}
237
238fn parse_ready_for_query(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
239    let status = cur.read_u8()?;
240    let status = TransactionStatus::from_byte(status)
241        .ok_or(ProtocolError::InvalidField("invalid transaction status"))?;
242    Ok(BackendMessage::ReadyForQuery(status))
243}
244
245fn parse_row_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
246    let count = cur.read_i16()?;
247    if count < 0 {
248        return Err(ProtocolError::InvalidField("negative field count"));
249    }
250    let mut fields = Vec::with_capacity(count as usize);
251    for _ in 0..count {
252        let name = cur.read_cstring()?;
253        let table_oid = cur.read_u32()?;
254        let column_id = cur.read_i16()?;
255        let type_oid = cur.read_u32()?;
256        let type_size = cur.read_i16()?;
257        let type_modifier = cur.read_i32()?;
258        let format = cur.read_i16()?;
259        fields.push(FieldDescription {
260            name,
261            table_oid,
262            column_id,
263            type_oid,
264            type_size,
265            type_modifier,
266            format,
267        });
268    }
269    Ok(BackendMessage::RowDescription(fields))
270}
271
272fn parse_data_row(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
273    let count = cur.read_i16()?;
274    if count < 0 {
275        return Err(ProtocolError::InvalidField("negative column count"));
276    }
277    let mut values = Vec::with_capacity(count as usize);
278    for _ in 0..count {
279        let len = cur.read_i32()?;
280        if len == -1 {
281            values.push(None);
282            continue;
283        }
284        if len < 0 {
285            return Err(ProtocolError::InvalidField("negative data length"));
286        }
287        let bytes = cur.read_bytes(len as usize)?.to_vec();
288        values.push(Some(bytes));
289    }
290    Ok(BackendMessage::DataRow(values))
291}
292
293fn parse_command_complete(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
294    let tag = cur.read_cstring()?;
295    Ok(BackendMessage::CommandComplete(tag))
296}
297
298fn parse_parameter_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
299    let count = cur.read_i16()?;
300    if count < 0 {
301        return Err(ProtocolError::InvalidField("negative parameter count"));
302    }
303    let mut oids = Vec::with_capacity(count as usize);
304    for _ in 0..count {
305        oids.push(cur.read_u32()?);
306    }
307    Ok(BackendMessage::ParameterDescription(oids))
308}
309
310fn parse_copy_in_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
311    let format = cur.read_i8()?;
312    let column_formats = read_column_formats(cur)?;
313    Ok(BackendMessage::CopyInResponse {
314        format,
315        column_formats,
316    })
317}
318
319fn parse_copy_out_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
320    let format = cur.read_i8()?;
321    let column_formats = read_column_formats(cur)?;
322    Ok(BackendMessage::CopyOutResponse {
323        format,
324        column_formats,
325    })
326}
327
328fn parse_copy_both_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
329    let format = cur.read_i8()?;
330    let column_formats = read_column_formats(cur)?;
331    Ok(BackendMessage::CopyBothResponse {
332        format,
333        column_formats,
334    })
335}
336
337fn read_column_formats(cur: &mut Cursor<'_>) -> Result<Vec<i16>, ProtocolError> {
338    let count = cur.read_i16()?;
339    if count < 0 {
340        return Err(ProtocolError::InvalidField("negative format count"));
341    }
342    let mut formats = Vec::with_capacity(count as usize);
343    for _ in 0..count {
344        formats.push(cur.read_i16()?);
345    }
346    Ok(formats)
347}
348
349fn parse_notification_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
350    let process_id = cur.read_i32()?;
351    let channel = cur.read_cstring()?;
352    let payload = cur.read_cstring()?;
353    Ok(BackendMessage::NotificationResponse {
354        process_id,
355        channel,
356        payload,
357    })
358}
359
360fn parse_function_call_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
361    let len = cur.read_i32()?;
362    if len == -1 {
363        return Ok(BackendMessage::FunctionCallResponse(None));
364    }
365    if len < 0 {
366        return Err(ProtocolError::InvalidField("negative function length"));
367    }
368    let bytes = cur.read_bytes(len as usize)?.to_vec();
369    Ok(BackendMessage::FunctionCallResponse(Some(bytes)))
370}
371
372fn parse_negotiate_protocol_version(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
373    let newest_minor = cur.read_i32()?;
374    let count = cur.read_i32()?;
375    if count < 0 {
376        return Err(ProtocolError::InvalidField(
377            "negative protocol option count",
378        ));
379    }
380    let mut unrecognized = Vec::with_capacity(count as usize);
381    for _ in 0..count {
382        unrecognized.push(cur.read_cstring()?);
383    }
384    Ok(BackendMessage::NegotiateProtocolVersion {
385        newest_minor,
386        unrecognized,
387    })
388}
389
390fn parse_error_response(
391    cur: &mut Cursor<'_>,
392    is_error: bool,
393) -> Result<BackendMessage, ProtocolError> {
394    let mut fields = ErrorFields::default();
395    loop {
396        let code = cur.read_u8()?;
397        if code == 0 {
398            break;
399        }
400        let value = cur.read_cstring()?;
401        match code {
402            b'S' => fields.severity = value,
403            b'V' => fields.severity_localized = Some(value),
404            b'C' => fields.code = value,
405            b'M' => fields.message = value,
406            b'D' => fields.detail = Some(value),
407            b'H' => fields.hint = Some(value),
408            b'P' => fields.position = value.parse().ok(),
409            b'p' => fields.internal_position = value.parse().ok(),
410            b'q' => fields.internal_query = Some(value),
411            b'W' => fields.where_ = Some(value),
412            b's' => fields.schema = Some(value),
413            b't' => fields.table = Some(value),
414            b'c' => fields.column = Some(value),
415            b'd' => fields.data_type = Some(value),
416            b'n' => fields.constraint = Some(value),
417            b'F' => fields.file = Some(value),
418            b'L' => fields.line = value.parse().ok(),
419            b'R' => fields.routine = Some(value),
420            _ => {
421                // Ignore unknown fields.
422            }
423        }
424    }
425
426    if is_error {
427        Ok(BackendMessage::ErrorResponse(fields))
428    } else {
429        Ok(BackendMessage::NoticeResponse(fields))
430    }
431}
432
/// Forward-only read position over a borrowed message payload.
#[derive(Debug)]
struct Cursor<'a> {
    /// The payload being decoded.
    buf: &'a [u8],
    /// Index into `buf` of the next unread byte.
    pos: usize,
}
438
439impl<'a> Cursor<'a> {
440    fn new(buf: &'a [u8]) -> Self {
441        Self { buf, pos: 0 }
442    }
443
444    fn remaining(&self) -> usize {
445        self.buf.len().saturating_sub(self.pos)
446    }
447
448    fn read_u8(&mut self) -> Result<u8, ProtocolError> {
449        if self.remaining() < 1 {
450            return Err(ProtocolError::UnexpectedEof);
451        }
452        let b = self.buf[self.pos];
453        self.pos += 1;
454        Ok(b)
455    }
456
457    fn read_i8(&mut self) -> Result<i8, ProtocolError> {
458        let b = self.read_u8()?;
459        Ok(b as i8)
460    }
461
462    fn read_i16(&mut self) -> Result<i16, ProtocolError> {
463        let bytes = self.read_bytes(2)?;
464        Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
465    }
466
467    fn read_u32(&mut self) -> Result<u32, ProtocolError> {
468        let bytes = self.read_bytes(4)?;
469        Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
470    }
471
472    fn read_i32(&mut self) -> Result<i32, ProtocolError> {
473        let bytes = self.read_bytes(4)?;
474        Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
475    }
476
477    fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ProtocolError> {
478        if self.remaining() < n {
479            return Err(ProtocolError::UnexpectedEof);
480        }
481        let start = self.pos;
482        let end = self.pos + n;
483        self.pos = end;
484        Ok(&self.buf[start..end])
485    }
486
487    fn read_cstring(&mut self) -> Result<String, ProtocolError> {
488        let start = self.pos;
489        while self.pos < self.buf.len() && self.buf[self.pos] != 0 {
490            self.pos += 1;
491        }
492        if self.pos >= self.buf.len() {
493            return Err(ProtocolError::UnexpectedEof);
494        }
495        let bytes = self.buf[start..self.pos].to_vec();
496        self.pos += 1; // consume null terminator
497        Ok(String::from_utf8(bytes)?)
498    }
499
500    fn take_remaining(&mut self) -> Vec<u8> {
501        let remaining = self.buf[self.pos..].to_vec();
502        self.pos = self.buf.len();
503        remaining
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `payload` in a backend frame: the type byte, then a big-endian
    /// length that counts itself plus the payload, then the payload.
    #[allow(clippy::cast_possible_truncation)]
    fn build_message(ty: u8, payload: &[u8]) -> Vec<u8> {
        let declared_len = (payload.len() + 4) as i32;
        let mut frame = vec![ty];
        frame.extend_from_slice(&declared_len.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Assert that a message of type `ty` whose payload is just the 16-bit
    /// count `-1` is rejected as an invalid field.
    fn assert_negative_count_rejected(ty: u8) {
        let payload = (-1_i16).to_be_bytes();
        let frame = build_message(ty, &payload);
        let outcome = MessageReader::parse_message(&frame);
        assert!(matches!(outcome, Err(ProtocolError::InvalidField(_))));
    }

    #[test]
    fn parse_auth_ok() {
        let payload = auth_type::OK.to_be_bytes();
        let frame = build_message(backend_type::AUTHENTICATION, &payload);
        let decoded = MessageReader::parse_message(&frame).unwrap();
        assert!(matches!(decoded, BackendMessage::AuthenticationOk));
    }

    #[test]
    fn parse_ready_for_query() {
        let frame = build_message(
            backend_type::READY_FOR_QUERY,
            &[TransactionStatus::Idle.as_byte()],
        );
        let decoded = MessageReader::parse_message(&frame).unwrap();
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(TransactionStatus::Idle)
        ));
    }

    #[test]
    fn parse_error_response() {
        // Three (code, value) pairs followed by the zero terminator.
        let mut payload = Vec::new();
        for (code, value) in [(b'S', "ERROR"), (b'C', "12345"), (b'M', "bad")] {
            payload.push(code);
            payload.extend_from_slice(value.as_bytes());
            payload.push(0);
        }
        payload.push(0);

        let frame = build_message(backend_type::ERROR_RESPONSE, &payload);
        let decoded = MessageReader::parse_message(&frame).unwrap();
        if let BackendMessage::ErrorResponse(fields) = decoded {
            assert_eq!(fields.severity, "ERROR");
            assert_eq!(fields.code, "12345");
            assert_eq!(fields.message, "bad");
        } else {
            panic!("unexpected message");
        }
    }

    #[test]
    fn parse_data_row() {
        // Two columns: the 3-byte value "foo", then NULL (length -1).
        let mut payload = Vec::new();
        payload.extend_from_slice(&2_i16.to_be_bytes());
        payload.extend_from_slice(&3_i32.to_be_bytes());
        payload.extend_from_slice(b"foo");
        payload.extend_from_slice(&(-1_i32).to_be_bytes());

        let frame = build_message(backend_type::DATA_ROW, &payload);
        let decoded = MessageReader::parse_message(&frame).unwrap();
        if let BackendMessage::DataRow(values) = decoded {
            assert_eq!(values.len(), 2);
            assert_eq!(values[0].as_deref(), Some(b"foo".as_slice()));
            assert!(values[1].is_none());
        } else {
            panic!("unexpected message");
        }
    }

    #[test]
    fn reader_buffers_partial_frames() {
        let frame = build_message(
            backend_type::READY_FOR_QUERY,
            &[TransactionStatus::Idle.as_byte()],
        );
        // Split inside the header so the first chunk cannot form a frame.
        let (head, tail) = frame.split_at(3);

        let mut reader = MessageReader::new();
        assert!(reader.feed(head).unwrap().is_empty());
        assert_eq!(reader.feed(tail).unwrap().len(), 1);
    }

    #[test]
    fn parse_row_description_negative_count_rejected() {
        // ROW_DESCRIPTION with negative field count (-1)
        assert_negative_count_rejected(backend_type::ROW_DESCRIPTION);
    }

    #[test]
    fn parse_data_row_negative_count_rejected() {
        // DATA_ROW with negative column count (-1)
        assert_negative_count_rejected(backend_type::DATA_ROW);
    }

    #[test]
    fn parse_parameter_description_negative_count_rejected() {
        // PARAMETER_DESCRIPTION with negative parameter count (-1)
        assert_negative_count_rejected(backend_type::PARAMETER_DESCRIPTION);
    }
}