Skip to main content

voltdb_client_rust/
protocol.rs

1//! Shared protocol handling for sync and async connections.
2//!
3//! This module contains common code for VoltDB wire protocol handling,
4//! including authentication handshake and response parsing.
5
6use std::io::Read;
7use std::net::Ipv4Addr;
8use std::str::from_utf8;
9
10use bytebuffer::ByteBuffer;
11use sha2::{Digest, Sha256};
12
13use crate::encode::VoltError;
14use crate::node::ConnInfo;
15
/// Version byte; [`build_auth_message`] writes it twice right after the length prefix.
pub const PROTOCOL_VERSION: u8 = 1;

/// Client handle reserved for keep-alive (ping) messages.
///
/// The value is `2^62`. NOTE(review): an earlier spelling, `1 << 63 - 1`, reads as if
/// `i64::MAX` was intended, but Rust parses it as `1 << 62`; that value is kept here —
/// confirm what the rest of the client expects before changing it.
pub const PING_HANDLE: i64 = 1 << 62;
21
22/// Build authentication message for VoltDB connection.
23///
24/// Returns the serialized authentication message as bytes.
25pub fn build_auth_message(user: Option<&str>, pass: Option<&str>) -> Result<Vec<u8>, VoltError> {
26    let mut buffer = ByteBuffer::new();
27    let version = [PROTOCOL_VERSION; 1];
28
29    // Message length placeholder (will be filled later)
30    buffer.write_u32(0);
31    // Protocol version
32    buffer.write_bytes(&version);
33    buffer.write_bytes(&version);
34    // Database name
35    buffer.write_string("database");
36
37    // Username
38    match user {
39        None => buffer.write_string(""),
40        Some(u) => buffer.write_string(u),
41    }
42
43    // Password hash (SHA-256)
44    let password_bytes = pass.map(|p| p.as_bytes()).unwrap_or(&[]);
45    let mut hasher: Sha256 = Sha256::new();
46    Digest::update(&mut hasher, password_bytes);
47    buffer.write_bytes(&hasher.finalize());
48
49    // Update message length
50    buffer.set_wpos(0);
51    buffer.write_u32((buffer.len() - 4) as u32);
52
53    Ok(buffer.into_vec())
54}
55
56/// Parse authentication response from server.
57///
58/// Returns connection info on success, or VoltError::AuthFailed on failure.
59pub fn parse_auth_response(data: &[u8]) -> Result<ConnInfo, VoltError> {
60    let mut res = ByteBuffer::from_bytes(data);
61
62    let _version = res.read_u8()?;
63    let auth = res.read_u8()?;
64
65    if auth != 0 {
66        return Err(VoltError::AuthFailed);
67    }
68
69    let host_id = res.read_i32()?;
70    let connection = res.read_i64()?;
71    let _ = res.read_i64()?; // timestamp
72    let leader = res.read_i32()?;
73    let bs = (leader as u32).to_be_bytes();
74    let leader_addr = Ipv4Addr::from(bs);
75
76    let length = res.read_i32()?;
77    let mut build = vec![0; length as usize];
78    res.read_exact(&mut build)?;
79    let b = from_utf8(&build)?;
80
81    Ok(ConnInfo {
82        host_id,
83        connection,
84        leader_addr,
85        build: String::from(b),
86    })
87}
88
89/// Read a length-prefixed message from a stream.
90///
91/// Returns the message payload (without the length prefix).
92pub fn read_message<R: Read>(reader: &mut R) -> Result<Vec<u8>, VoltError> {
93    use byteorder::{BigEndian, ReadBytesExt};
94
95    let len = reader.read_u32::<BigEndian>()?;
96    if len == 0 {
97        return Ok(Vec::new());
98    }
99
100    let mut data = vec![0u8; len as usize];
101    reader.read_exact(&mut data)?;
102    Ok(data)
103}
104
105/// Parse the handle from a response message.
106///
107/// Assumes the first byte is status and the next 8 bytes are the handle.
108pub fn parse_response_handle(data: &[u8]) -> Result<i64, VoltError> {
109    if data.len() < 9 {
110        return Err(VoltError::Other("Response too short".to_string()));
111    }
112
113    let mut buffer = ByteBuffer::from_bytes(&data[1..9]);
114    Ok(buffer.read_i64()?)
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode the big-endian `u32` length prefix that opens every auth message.
    fn declared_len(msg: &[u8]) -> usize {
        let mut prefix = [0u8; 4];
        prefix.copy_from_slice(&msg[..4]);
        u32::from_be_bytes(prefix) as usize
    }

    #[test]
    fn test_build_auth_message_no_credentials() {
        let msg = build_auth_message(None, None).expect("anonymous message");
        assert!(!msg.is_empty());
        // The prefix counts every byte after itself.
        assert_eq!(declared_len(&msg), msg.len() - 4);
    }

    #[test]
    fn test_build_auth_message_with_credentials() {
        let admin = build_auth_message(Some("admin"), Some("password")).expect("admin message");
        assert!(!admin.is_empty());
        // The user name adds bytes; the SHA-256 digest is the same size either way.
        let anonymous = build_auth_message(None, None).expect("anonymous message");
        assert!(admin.len() > anonymous.len());
    }

    #[test]
    fn test_build_auth_message_user_only() {
        let msg = build_auth_message(Some("testuser"), None).expect("user-only message");
        assert!(!msg.is_empty());
        assert_eq!(declared_len(&msg), msg.len() - 4);
    }

    #[test]
    fn test_build_auth_message_contains_protocol_version() {
        let msg = build_auth_message(None, None).expect("anonymous message");
        // Offsets 4 and 5 immediately follow the length prefix.
        assert_eq!(&msg[4..6], &[PROTOCOL_VERSION, PROTOCOL_VERSION]);
    }

    #[test]
    fn test_build_auth_message_contains_database() {
        let msg = build_auth_message(None, None).expect("anonymous message");
        assert!(msg.windows(b"database".len()).any(|w| w == b"database"));
    }

    #[test]
    fn test_parse_response_handle_valid() {
        // Status byte, then the handle as a big-endian i64.
        let mut data = [0u8; 9];
        data[8] = 42;
        assert_eq!(parse_response_handle(&data).expect("valid handle"), 42);
    }

    #[test]
    fn test_parse_response_handle_negative() {
        // All handle bits set => -1 in two's complement.
        let mut data = [0xFFu8; 9];
        data[0] = 0; // status byte
        assert_eq!(parse_response_handle(&data).expect("valid handle"), -1);
    }

    #[test]
    fn test_parse_response_handle_too_short() {
        // Four bytes cannot hold a status byte plus an 8-byte handle.
        assert!(parse_response_handle(&[0u8, 1, 2, 3]).is_err());
    }

    #[test]
    fn test_parse_auth_response_invalid_auth() {
        // version = 1, auth result = 1 (rejected), then zero padding so that no read
        // fails for lack of data.
        let mut data = vec![0u8; 52];
        data[0] = 1;
        data[1] = 1;
        assert!(matches!(parse_auth_response(&data), Err(VoltError::AuthFailed)));
    }

    #[test]
    fn test_ping_handle_constant() {
        // 2^62: a large positive handle.
        assert_eq!(PING_HANDLE, 1i64 << 62);
    }
}