Skip to main content

sentinel_driver/protocol/
backend.rs

1use bytes::Bytes;
2
3use crate::error::{Error, Result};
4
5/// A decoded backend (server → client) message.
6#[derive(Debug)]
7pub enum BackendMessage {
8    AuthenticationOk,
9    AuthenticationCleartextPassword,
10    AuthenticationMd5Password {
11        salt: [u8; 4],
12    },
13    AuthenticationSasl {
14        mechanisms: Vec<String>,
15    },
16    AuthenticationSaslContinue {
17        data: Vec<u8>,
18    },
19    AuthenticationSaslFinal {
20        data: Vec<u8>,
21    },
22
23    BackendKeyData {
24        process_id: i32,
25        secret_key: i32,
26    },
27
28    ParameterStatus {
29        name: String,
30        value: String,
31    },
32
33    ReadyForQuery {
34        transaction_status: TransactionStatus,
35    },
36
37    RowDescription {
38        fields: Vec<FieldDescription>,
39    },
40
41    DataRow {
42        columns: DataRowColumns,
43    },
44
45    CommandComplete {
46        tag: String,
47    },
48
49    EmptyQueryResponse,
50
51    ErrorResponse {
52        fields: ErrorFields,
53    },
54
55    NoticeResponse {
56        fields: ErrorFields,
57    },
58
59    ParseComplete,
60    BindComplete,
61    CloseComplete,
62    NoData,
63    PortalSuspended,
64
65    ParameterDescription {
66        oids: Vec<u32>,
67    },
68
69    CopyInResponse {
70        format: CopyFormat,
71        column_formats: Vec<i16>,
72    },
73    CopyOutResponse {
74        format: CopyFormat,
75        column_formats: Vec<i16>,
76    },
77    CopyData {
78        data: Bytes,
79    },
80    CopyDone,
81
82    NotificationResponse {
83        process_id: i32,
84        channel: String,
85        payload: String,
86    },
87}
88
/// Transaction state reported by a `ReadyForQuery` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Wire byte `I`: not inside a transaction block.
    Idle,
    /// Wire byte `T`: inside a transaction block.
    InTransaction,
    /// Wire byte `E`: inside a failed transaction block.
    Failed,
}
99
/// Overall data format announced by a COPY response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CopyFormat {
    /// Wire value 0.
    Text,
    /// Wire value 1.
    Binary,
}
106
/// Metadata for one result column, as carried in a `RowDescription` message.
///
/// Fields appear in the same order as on the wire: a NUL-terminated name
/// followed by 18 bytes of fixed-width metadata.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// Table OID reported by the server.
    pub table_oid: u32,
    /// Column attribute number reported by the server.
    pub column_id: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size reported by the server.
    pub type_size: i16,
    /// Type modifier reported by the server.
    pub type_modifier: i32,
    /// Format code reported by the server.
    pub format: i16,
}
118
/// Fields of an `ErrorResponse` / `NoticeResponse` message.
///
/// The decoder fills each field from the one-letter field code shown on it. If a
/// field was not sent, its `Option` is `None` and each of the three `String`
/// fields stays empty. Numeric fields are `None` when the value does not parse
/// as `u32`.
#[derive(Debug, Clone)]
pub struct ErrorFields {
    /// Field code `S`.
    pub severity: String,
    /// Field code `C`.
    pub code: String,
    /// Field code `M`.
    pub message: String,
    /// Field code `D`.
    pub detail: Option<String>,
    /// Field code `H`.
    pub hint: Option<String>,
    /// Field code `P`, parsed as `u32`.
    pub position: Option<u32>,
    /// Field code `p`, parsed as `u32`.
    pub internal_position: Option<u32>,
    /// Field code `q`.
    pub internal_query: Option<String>,
    /// Field code `W` (trailing underscore because `where` is a keyword).
    pub where_: Option<String>,
    /// Field code `s`.
    pub schema: Option<String>,
    /// Field code `t`.
    pub table: Option<String>,
    /// Field code `c`.
    pub column: Option<String>,
    /// Field code `d`.
    pub data_type: Option<String>,
    /// Field code `n`.
    pub constraint: Option<String>,
    /// Field code `F`.
    pub file: Option<String>,
    /// Field code `L`, parsed as `u32`.
    pub line: Option<u32>,
    /// Field code `R`.
    pub routine: Option<String>,
}
140
141/// Zero-copy column data from a DataRow message.
142///
143/// Stores the raw buffer and column offsets for deferred decoding.
144#[derive(Debug)]
145pub struct DataRowColumns {
146    buf: Bytes,
147    /// Each entry is (offset, length). length == -1 means NULL.
148    columns: Vec<(usize, i32)>,
149}
150
151impl DataRowColumns {
152    /// Number of columns.
153    pub fn len(&self) -> usize {
154        self.columns.len()
155    }
156
157    pub fn is_empty(&self) -> bool {
158        self.columns.is_empty()
159    }
160
161    /// Get raw bytes for column at `idx`. Returns `None` for NULL.
162    pub fn get(&self, idx: usize) -> Option<Bytes> {
163        let &(offset, len) = self.columns.get(idx)?;
164        if len < 0 {
165            None // NULL
166        } else {
167            Some(self.buf.slice(offset..offset + len as usize))
168        }
169    }
170
171    /// Returns `true` if column `idx` is NULL.
172    pub fn is_null(&self, idx: usize) -> bool {
173        self.columns.get(idx).map_or(true, |&(_, len)| len < 0)
174    }
175}
176
177/// Decode a single backend message from a raw frame.
178///
179/// `msg_type` is the first byte, `body` is the payload (after the length).
180pub fn decode(msg_type: u8, body: Bytes) -> Result<BackendMessage> {
181    match msg_type {
182        b'R' => decode_auth(&body),
183        b'K' => decode_backend_key_data(&body),
184        b'S' => decode_parameter_status(&body),
185        b'Z' => decode_ready_for_query(&body),
186        b'T' => decode_row_description(&body),
187        b'D' => decode_data_row(body),
188        b'C' => decode_command_complete(&body),
189        b'I' => Ok(BackendMessage::EmptyQueryResponse),
190        b'E' => decode_error_response(&body),
191        b'N' => decode_notice_response(&body),
192        b'1' => Ok(BackendMessage::ParseComplete),
193        b'2' => Ok(BackendMessage::BindComplete),
194        b'3' => Ok(BackendMessage::CloseComplete),
195        b'n' => Ok(BackendMessage::NoData),
196        b's' => Ok(BackendMessage::PortalSuspended),
197        b't' => decode_parameter_description(&body),
198        b'G' => decode_copy_in_response(&body),
199        b'H' => decode_copy_out_response(&body),
200        b'd' => Ok(BackendMessage::CopyData { data: body }),
201        b'c' => Ok(BackendMessage::CopyDone),
202        b'A' => decode_notification(&body),
203        _ => Err(Error::protocol(format!(
204            "unknown message type: 0x{msg_type:02x}"
205        ))),
206    }
207}
208
209// ── Decoders ─────────────────────────────────────────
210
211fn decode_auth(body: &[u8]) -> Result<BackendMessage> {
212    if body.len() < 4 {
213        return Err(Error::protocol("auth message too short"));
214    }
215    let auth_type = read_i32(body, 0);
216
217    match auth_type {
218        0 => Ok(BackendMessage::AuthenticationOk),
219        3 => Ok(BackendMessage::AuthenticationCleartextPassword),
220        5 => {
221            if body.len() < 8 {
222                return Err(Error::protocol("MD5 auth message too short"));
223            }
224            let mut salt = [0u8; 4];
225            salt.copy_from_slice(&body[4..8]);
226            Ok(BackendMessage::AuthenticationMd5Password { salt })
227        }
228        10 => {
229            // SASL — parse null-separated mechanism list
230            let mut mechanisms = Vec::new();
231            let mut pos = 4;
232            loop {
233                if pos >= body.len() {
234                    break;
235                }
236                let s = read_cstr(body, &mut pos)?;
237                if s.is_empty() {
238                    break;
239                }
240                mechanisms.push(s);
241            }
242            Ok(BackendMessage::AuthenticationSasl { mechanisms })
243        }
244        11 => Ok(BackendMessage::AuthenticationSaslContinue {
245            data: body[4..].to_vec(),
246        }),
247        12 => Ok(BackendMessage::AuthenticationSaslFinal {
248            data: body[4..].to_vec(),
249        }),
250        _ => Err(Error::protocol(format!(
251            "unsupported auth type: {auth_type}"
252        ))),
253    }
254}
255
256fn decode_backend_key_data(body: &[u8]) -> Result<BackendMessage> {
257    if body.len() < 8 {
258        return Err(Error::protocol("BackendKeyData too short"));
259    }
260    Ok(BackendMessage::BackendKeyData {
261        process_id: read_i32(body, 0),
262        secret_key: read_i32(body, 4),
263    })
264}
265
266fn decode_parameter_status(body: &[u8]) -> Result<BackendMessage> {
267    let mut pos = 0;
268    let name = read_cstr(body, &mut pos)?;
269    let value = read_cstr(body, &mut pos)?;
270    Ok(BackendMessage::ParameterStatus { name, value })
271}
272
273fn decode_ready_for_query(body: &[u8]) -> Result<BackendMessage> {
274    if body.is_empty() {
275        return Err(Error::protocol("ReadyForQuery empty"));
276    }
277    let status = match body[0] {
278        b'I' => TransactionStatus::Idle,
279        b'T' => TransactionStatus::InTransaction,
280        b'E' => TransactionStatus::Failed,
281        s => return Err(Error::protocol(format!("unknown transaction status: {s}"))),
282    };
283    Ok(BackendMessage::ReadyForQuery {
284        transaction_status: status,
285    })
286}
287
288fn decode_row_description(body: &[u8]) -> Result<BackendMessage> {
289    if body.len() < 2 {
290        return Err(Error::protocol("RowDescription too short"));
291    }
292    let field_count = read_i16(body, 0) as usize;
293    let mut fields = Vec::with_capacity(field_count);
294    let mut pos = 2;
295
296    for _ in 0..field_count {
297        let name = read_cstr(body, &mut pos)?;
298
299        if pos + 18 > body.len() {
300            return Err(Error::protocol("RowDescription field truncated"));
301        }
302
303        let table_oid = read_u32(body, pos);
304        let column_id = read_i16(body, pos + 4);
305        let type_oid = read_u32(body, pos + 6);
306        let type_size = read_i16(body, pos + 10);
307        let type_modifier = read_i32(body, pos + 12);
308        let format = read_i16(body, pos + 16);
309        pos += 18;
310
311        fields.push(FieldDescription {
312            name,
313            table_oid,
314            column_id,
315            type_oid,
316            type_size,
317            type_modifier,
318            format,
319        });
320    }
321
322    Ok(BackendMessage::RowDescription { fields })
323}
324
325fn decode_data_row(body: Bytes) -> Result<BackendMessage> {
326    if body.len() < 2 {
327        return Err(Error::protocol("DataRow too short"));
328    }
329    let col_count = read_i16(&body, 0) as usize;
330    let mut columns = Vec::with_capacity(col_count);
331    let mut pos = 2;
332
333    for _ in 0..col_count {
334        if pos + 4 > body.len() {
335            return Err(Error::protocol("DataRow column truncated"));
336        }
337        let len = read_i32(&body, pos);
338        pos += 4;
339
340        if len < 0 {
341            columns.push((0, -1)); // NULL
342        } else {
343            let len_usize = len as usize;
344            if pos + len_usize > body.len() {
345                return Err(Error::protocol("DataRow column data truncated"));
346            }
347            columns.push((pos, len));
348            pos += len_usize;
349        }
350    }
351
352    Ok(BackendMessage::DataRow {
353        columns: DataRowColumns { buf: body, columns },
354    })
355}
356
357fn decode_command_complete(body: &[u8]) -> Result<BackendMessage> {
358    let mut pos = 0;
359    let tag = read_cstr(body, &mut pos)?;
360    Ok(BackendMessage::CommandComplete { tag })
361}
362
363fn decode_error_notice_fields(body: &[u8]) -> Result<ErrorFields> {
364    let mut severity = String::new();
365    let mut code = String::new();
366    let mut message = String::new();
367    let mut detail = None;
368    let mut hint = None;
369    let mut position = None;
370    let mut internal_position = None;
371    let mut internal_query = None;
372    let mut where_ = None;
373    let mut schema = None;
374    let mut table = None;
375    let mut column = None;
376    let mut data_type = None;
377    let mut constraint = None;
378    let mut file = None;
379    let mut line = None;
380    let mut routine = None;
381
382    let mut pos = 0;
383    loop {
384        if pos >= body.len() {
385            break;
386        }
387        let field_type = body[pos];
388        pos += 1;
389        if field_type == 0 {
390            break;
391        }
392        let value = read_cstr(body, &mut pos)?;
393
394        match field_type {
395            b'S' => severity = value,
396            b'C' => code = value,
397            b'M' => message = value,
398            b'D' => detail = Some(value),
399            b'H' => hint = Some(value),
400            b'P' => position = value.parse().ok(),
401            b'p' => internal_position = value.parse().ok(),
402            b'q' => internal_query = Some(value),
403            b'W' => where_ = Some(value),
404            b's' => schema = Some(value),
405            b't' => table = Some(value),
406            b'c' => column = Some(value),
407            b'd' => data_type = Some(value),
408            b'n' => constraint = Some(value),
409            b'F' => file = Some(value),
410            b'L' => line = value.parse().ok(),
411            b'R' => routine = Some(value),
412            _ => {} // ignore unknown fields
413        }
414    }
415
416    Ok(ErrorFields {
417        severity,
418        code,
419        message,
420        detail,
421        hint,
422        position,
423        internal_position,
424        internal_query,
425        where_,
426        schema,
427        table,
428        column,
429        data_type,
430        constraint,
431        file,
432        line,
433        routine,
434    })
435}
436
437fn decode_error_response(body: &[u8]) -> Result<BackendMessage> {
438    let fields = decode_error_notice_fields(body)?;
439    Ok(BackendMessage::ErrorResponse { fields })
440}
441
442fn decode_notice_response(body: &[u8]) -> Result<BackendMessage> {
443    let fields = decode_error_notice_fields(body)?;
444    Ok(BackendMessage::NoticeResponse { fields })
445}
446
447fn decode_parameter_description(body: &[u8]) -> Result<BackendMessage> {
448    if body.len() < 2 {
449        return Err(Error::protocol("ParameterDescription too short"));
450    }
451    let count = read_i16(body, 0) as usize;
452    let mut oids = Vec::with_capacity(count);
453    let mut pos = 2;
454
455    for _ in 0..count {
456        if pos + 4 > body.len() {
457            return Err(Error::protocol("ParameterDescription truncated"));
458        }
459        oids.push(read_u32(body, pos));
460        pos += 4;
461    }
462
463    Ok(BackendMessage::ParameterDescription { oids })
464}
465
466fn decode_copy_response(body: &[u8]) -> Result<(CopyFormat, Vec<i16>)> {
467    if body.len() < 3 {
468        return Err(Error::protocol("CopyResponse too short"));
469    }
470    let format = match body[0] {
471        0 => CopyFormat::Text,
472        1 => CopyFormat::Binary,
473        f => return Err(Error::protocol(format!("unknown copy format: {f}"))),
474    };
475    let col_count = read_i16(body, 1) as usize;
476    let mut column_formats = Vec::with_capacity(col_count);
477    let mut pos = 3;
478
479    for _ in 0..col_count {
480        if pos + 2 > body.len() {
481            return Err(Error::protocol("CopyResponse column formats truncated"));
482        }
483        column_formats.push(read_i16(body, pos));
484        pos += 2;
485    }
486
487    Ok((format, column_formats))
488}
489
490fn decode_copy_in_response(body: &[u8]) -> Result<BackendMessage> {
491    let (format, column_formats) = decode_copy_response(body)?;
492    Ok(BackendMessage::CopyInResponse {
493        format,
494        column_formats,
495    })
496}
497
498fn decode_copy_out_response(body: &[u8]) -> Result<BackendMessage> {
499    let (format, column_formats) = decode_copy_response(body)?;
500    Ok(BackendMessage::CopyOutResponse {
501        format,
502        column_formats,
503    })
504}
505
506fn decode_notification(body: &[u8]) -> Result<BackendMessage> {
507    if body.len() < 4 {
508        return Err(Error::protocol("NotificationResponse too short"));
509    }
510    let process_id = read_i32(body, 0);
511    let mut pos = 4;
512    let channel = read_cstr(body, &mut pos)?;
513    let payload = read_cstr(body, &mut pos)?;
514
515    Ok(BackendMessage::NotificationResponse {
516        process_id,
517        channel,
518        payload,
519    })
520}
521
522// ── Read helpers ─────────────────────────────────────
523
/// Reads a big-endian `i32` from `buf[offset..offset + 4]`.
///
/// # Panics
///
/// Panics if fewer than four bytes are available at `offset`; callers check lengths first.
fn read_i32(buf: &[u8], offset: usize) -> i32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[offset..offset + 4]);
    i32::from_be_bytes(raw)
}
532
/// Reads a big-endian `u32` from `buf[offset..offset + 4]`.
///
/// # Panics
///
/// Panics if fewer than four bytes are available at `offset`; callers check lengths first.
fn read_u32(buf: &[u8], offset: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[offset..offset + 4]);
    u32::from_be_bytes(raw)
}
541
/// Reads a big-endian `i16` from `buf[offset..offset + 2]`.
///
/// # Panics
///
/// Panics if fewer than two bytes are available at `offset`; callers check lengths first.
fn read_i16(buf: &[u8], offset: usize) -> i16 {
    let pair = &buf[offset..offset + 2];
    i16::from_be_bytes([pair[0], pair[1]])
}
545
546/// Read a null-terminated string starting at `pos`, advancing `pos` past the null.
547fn read_cstr(buf: &[u8], pos: &mut usize) -> Result<String> {
548    let start = *pos;
549    let null_pos = buf[start..]
550        .iter()
551        .position(|&b| b == 0)
552        .ok_or_else(|| Error::protocol("missing null terminator"))?;
553
554    let s = std::str::from_utf8(&buf[start..start + null_pos])
555        .map_err(|e| Error::protocol(format!("invalid UTF-8 in message: {e}")))?
556        .to_string();
557
558    *pos = start + null_pos + 1;
559    Ok(s)
560}