zero_postgres/protocol/frontend/
extended.rs

1//! Extended query protocol messages.
2
3use crate::conversion::ToParams;
4use crate::error::Result;
5use crate::protocol::codec::MessageBuilder;
6use crate::protocol::types::{FormatCode, Oid, preferred_format};
7
8/// Write a Parse message to create a prepared statement.
9///
10/// - `name`: Statement name (empty string for unnamed statement)
11/// - `query`: SQL query with $1, $2, ... placeholders
12/// - `param_oids`: Parameter type OIDs (0 = let server infer)
13pub fn write_parse(buf: &mut Vec<u8>, name: &str, query: &str, param_oids: &[Oid]) {
14    log::debug!("PARSE {query}");
15    let mut msg = MessageBuilder::new(buf, super::msg_type::PARSE);
16    msg.write_cstr(name);
17    msg.write_cstr(query);
18    msg.write_i16(param_oids.len() as i16);
19    for &oid in param_oids {
20        msg.write_i32(oid as i32);
21    }
22    msg.finish();
23}
24
25/// Write a Bind message to create a portal from a prepared statement.
26///
27/// - `portal`: Portal name (empty string for unnamed portal)
28/// - `statement_name`: Prepared statement name
29/// - `params`: Parameter values (tuple of ToValue types)
30/// - `target_oids`: Target OIDs for encoding parameters
31///
32/// Uses per-parameter format codes based on `preferred_format()`:
33/// - NUMERIC uses text format (0)
34/// - All other types use binary format (1)
35pub fn write_bind<P: ToParams>(
36    buf: &mut Vec<u8>,
37    portal: &str,
38    statement_name: &str,
39    params: &P,
40    target_oids: &[Oid],
41) -> Result<()> {
42    log::debug!(
43        "BIND {} {}",
44        if statement_name.is_empty() {
45            "<unnamed statement>"
46        } else {
47            statement_name
48        },
49        if portal.is_empty() {
50            "<unnamed portal>"
51        } else {
52            portal
53        }
54    );
55    let mut msg = MessageBuilder::new(buf, super::msg_type::BIND);
56
57    // Portal and statement names
58    msg.write_cstr(portal);
59    msg.write_cstr(statement_name);
60
61    // Parameter format codes: one per parameter
62    let param_count = params.param_count();
63    msg.write_i16(param_count as i16);
64    for &oid in target_oids {
65        msg.write_i16(preferred_format(oid) as i16);
66    }
67
68    // Parameter values (count + length-prefixed data)
69    msg.write_i16(param_count as i16);
70    params.encode(target_oids, msg.buf())?;
71
72    // Result format codes: 1 code that applies to all columns (binary)
73    msg.write_i16(1);
74    msg.write_i16(FormatCode::Binary as i16);
75
76    msg.finish();
77    Ok(())
78}
79
80/// Write an Execute message to run a portal.
81///
82/// - `portal`: Portal name
83/// - `max_rows`: Maximum number of rows to return (0 = unlimited)
84pub fn write_execute(buf: &mut Vec<u8>, portal: &str, max_rows: u32) {
85    log::debug!(
86        "EXECUTE {} LIMIT {max_rows}",
87        if portal.is_empty() {
88            "<unnamed portal>"
89        } else {
90            portal
91        }
92    );
93    let mut msg = MessageBuilder::new(buf, super::msg_type::EXECUTE);
94    msg.write_cstr(portal);
95    msg.write_i32(max_rows as i32);
96    msg.finish();
97}
98
99/// Write a Describe message to get metadata.
100///
101/// - `describe_type`: 'S' for statement, 'P' for portal
102/// - `name`: Statement or portal name
103pub fn write_describe(buf: &mut Vec<u8>, describe_type: u8, name: &str) {
104    log::debug!("DESCRIBE({}) {name}", describe_type as char);
105    let mut msg = MessageBuilder::new(buf, super::msg_type::DESCRIBE);
106    msg.write_u8(describe_type);
107    msg.write_cstr(name);
108    msg.finish();
109}
110
111/// Write a Describe message for a statement.
112pub fn write_describe_statement(buf: &mut Vec<u8>, name: &str) {
113    write_describe(buf, b'S', name);
114}
115
116/// Write a Describe message for a portal.
117pub fn write_describe_portal(buf: &mut Vec<u8>, name: &str) {
118    write_describe(buf, b'P', name);
119}
120
121/// Write a Close message to release a statement or portal.
122///
123/// - `close_type`: 'S' for statement, 'P' for portal
124/// - `name`: Statement or portal name
125pub fn write_close(buf: &mut Vec<u8>, close_type: u8, name: &str) {
126    log::debug!("CLOSE({}) {name}", close_type as char);
127    let mut msg = MessageBuilder::new(buf, super::msg_type::CLOSE);
128    msg.write_u8(close_type);
129    msg.write_cstr(name);
130    msg.finish();
131}
132
133/// Write a Close message for a statement.
134pub fn write_close_statement(buf: &mut Vec<u8>, name: &str) {
135    write_close(buf, b'S', name);
136}
137
138/// Write a Close message for a portal.
139pub fn write_close_portal(buf: &mut Vec<u8>, name: &str) {
140    write_close(buf, b'P', name);
141}
142
143/// Write a Sync message.
144///
145/// This ends an extended query sequence and causes:
146/// - Implicit COMMIT if successful and not in explicit transaction
147/// - Implicit ROLLBACK if failed and not in explicit transaction
148/// - Server responds with ReadyForQuery
149pub fn write_sync(buf: &mut Vec<u8>) {
150    log::debug!("SYNC");
151    let msg = MessageBuilder::new(buf, super::msg_type::SYNC);
152    msg.finish();
153}
154
155/// Write a Flush message.
156///
157/// Forces the server to send all pending responses without waiting for Sync.
158/// Useful for pipelining when you need intermediate results.
159pub fn write_flush(buf: &mut Vec<u8>) {
160    log::debug!("FLUSH");
161    let msg = MessageBuilder::new(buf, super::msg_type::FLUSH);
162    msg.finish();
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the big-endian length field that follows a message's type byte.
    fn length_of(frame: &[u8]) -> i32 {
        i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]])
    }

    #[test]
    fn test_parse() {
        let mut out = Vec::new();
        write_parse(&mut out, "stmt1", "SELECT $1::int", &[0]);

        assert_eq!(out[0], b'P');
        // The length field counts every byte except the type byte.
        assert_eq!(length_of(&out) as usize, out.len() - 1);
    }

    #[test]
    fn test_sync() {
        let mut out = Vec::new();
        write_sync(&mut out);

        // Type byte plus a length field that only covers itself.
        assert_eq!(out.len(), 5);
        assert_eq!(out[0], b'S');
        assert_eq!(length_of(&out), 4);
    }

    #[test]
    fn test_flush() {
        let mut out = Vec::new();
        write_flush(&mut out);

        // Type byte plus a length field that only covers itself.
        assert_eq!(out.len(), 5);
        assert_eq!(out[0], b'H');
        assert_eq!(length_of(&out), 4);
    }

    #[test]
    fn test_execute() {
        let mut out = Vec::new();
        write_execute(&mut out, "", 0);

        assert_eq!(out[0], b'E');
        // 4 (length field) + 1 (empty portal name's NUL terminator) + 4 (max_rows) = 9
        assert_eq!(length_of(&out), 9);
    }
}
211}