veddb_client/
connection.rs

1//! Connection handling for VedDB client
2
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Duration;
6
7use bytes::Bytes;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::Mutex;
11use tokio::time::timeout;
12use tracing::{debug, error, info, trace};
13
14use crate::types::{Command, Response};
15use crate::{Error, Result};
16
/// Default deadline for establishing the TCP connection (5 seconds).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default deadline applied to each write, flush and read of a request (30 seconds).
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Largest response payload (excluding the 20-byte header) accepted from the server: 16 MiB.
const MAX_FRAME_SIZE: usize = 16 << 20;
23
/// A single TCP connection to a VedDB server.
///
/// The stream sits behind an async (`tokio::sync`) mutex that `execute` holds for
/// the whole write-then-read exchange. Requests on one `Connection` are therefore
/// serialized: at most one is in flight at a time.
#[derive(Debug)]
pub struct Connection {
    /// The underlying TCP stream; locked for an entire request/response exchange
    stream: Mutex<TcpStream>,
    /// Address of the server this connection targets
    addr: SocketAddr,
    /// Next sequence number to hand out; starts at 1 and is incremented atomically
    next_seq: AtomicU32,
    /// Timeout the TCP connect was performed with
    connect_timeout: Duration,
    /// Deadline applied to each individual write, flush and read of a request
    request_timeout: Duration,
}
38
39impl Connection {
40    /// Create a new connection to the specified address
41    pub async fn connect(addr: impl Into<SocketAddr>) -> Result<Self> {
42        Self::connect_with_timeout(addr, DEFAULT_CONNECT_TIMEOUT, DEFAULT_REQUEST_TIMEOUT).await
43    }
44
45    /// Create a new connection with custom timeouts
46    pub async fn connect_with_timeout(
47        addr: impl Into<SocketAddr>,
48        connect_timeout: Duration,
49        request_timeout: Duration,
50    ) -> Result<Self> {
51        let addr = addr.into();
52        info!("Connecting to VedDB server at {}", addr);
53
54        let stream = timeout(connect_timeout, TcpStream::connect(&addr))
55            .await
56            .map_err(Error::Timeout)??;
57
58        info!("Connected to VedDB server at {}", addr);
59
60        Ok(Self {
61            stream: Mutex::new(stream),
62            addr,
63            next_seq: AtomicU32::new(1),
64            connect_timeout,
65            request_timeout,
66        })
67    }
68
69    /// Get the next sequence number
70    fn next_seq(&self) -> u32 {
71        self.next_seq.fetch_add(1, Ordering::SeqCst)
72    }
73
74    /// Execute a command and return the response
75    pub async fn execute(&self, cmd: Command) -> Result<Response> {
76        let seq = cmd.header.seq;
77        debug!("Executing command: {:?} (seq={})", cmd.header.opcode, seq);
78
79        let mut stream = self.stream.lock().await;
80
81        // Send the command
82        let cmd_bytes = cmd.to_bytes();
83        debug!("Sending command: {} bytes", cmd_bytes.len());
84        debug!("Command header: {:?}", &cmd_bytes[..std::cmp::min(24, cmd_bytes.len())]);
85
86        timeout(self.request_timeout, stream.write_all(&cmd_bytes))
87            .await
88            .map_err(Error::Timeout)??;
89        
90        debug!("Command sent, flushing...");
91        timeout(self.request_timeout, stream.flush())
92            .await
93            .map_err(Error::Timeout)??;
94        debug!("Command flushed");
95
96        // Read the response header (20 bytes)
97        debug!("Reading response header (20 bytes)...");
98        let mut header_buf = [0u8; 20];
99        timeout(self.request_timeout, stream.read_exact(&mut header_buf))
100            .await
101            .map_err(Error::Timeout)??;
102        debug!("Response header received: {:?}", &header_buf[..8]);
103
104        // Parse the header
105        let payload_len =
106            u32::from_le_bytes([header_buf[8], header_buf[9], header_buf[10], header_buf[11]]);
107
108        if payload_len as usize > MAX_FRAME_SIZE {
109            return Err(Error::Protocol(format!(
110                "Response too large: {} bytes (max: {})",
111                payload_len, MAX_FRAME_SIZE
112            )));
113        }
114
115        // Read the payload
116        let mut payload = vec![0u8; payload_len as usize];
117        if payload_len > 0 {
118            timeout(self.request_timeout, stream.read_exact(&mut payload))
119                .await
120                .map_err(Error::Timeout)??;
121        }
122
123        // Combine header and payload for parsing
124        let mut response_bytes = Vec::with_capacity(20 + payload_len as usize);
125        response_bytes.extend_from_slice(&header_buf);
126        response_bytes.extend_from_slice(&payload);
127
128        let response = Response::from_bytes(&response_bytes)
129            .map_err(|e| Error::Protocol(format!("Invalid response: {}", e)))?;
130
131        // Verify sequence number
132        if response.header.seq != seq {
133            return Err(Error::Protocol(format!(
134                "Sequence number mismatch: expected {}, got {}",
135                seq, response.header.seq
136            )));
137        }
138
139        // Check for server errors
140        if !response.is_ok() {
141            let status = response.status();
142            let error_msg = String::from_utf8_lossy(&response.payload).into_owned();
143            return Err(Error::Server(format!(
144                "Server error: {:?}: {}",
145                status, error_msg
146            )));
147        }
148
149        Ok(response)
150    }
151
152    /// Ping the server
153    pub async fn ping(&self) -> Result<()> {
154        let seq = self.next_seq();
155        let cmd = Command::ping(seq);
156        self.execute(cmd).await?;
157        Ok(())
158    }
159
160    /// Set a key-value pair
161    pub async fn set<K, V>(&self, key: K, value: V) -> Result<()>
162    where
163        K: Into<Bytes>,
164        V: Into<Bytes>,
165    {
166        let seq = self.next_seq();
167        let cmd = Command::set(seq, key, value);
168        self.execute(cmd).await?;
169        Ok(())
170    }
171
172    /// Get a value by key
173    pub async fn get<K>(&self, key: K) -> Result<Bytes>
174    where
175        K: Into<Bytes>,
176    {
177        let seq = self.next_seq();
178        let cmd = Command::get(seq, key);
179        let response = self.execute(cmd).await?;
180        Ok(response.payload)
181    }
182
183    /// Delete a key
184    pub async fn delete<K>(&self, key: K) -> Result<()>
185    where
186        K: Into<Bytes>,
187    {
188        let seq = self.next_seq();
189        let cmd = Command::delete(seq, key);
190        self.execute(cmd).await?;
191        Ok(())
192    }
193
194    /// Compare and swap a value
195    pub async fn cas<K, V>(&self, key: K, expected_version: u64, value: V) -> Result<()>
196    where
197        K: Into<Bytes>,
198        V: Into<Bytes>,
199    {
200        let seq = self.next_seq();
201        let cmd = Command::cas(seq, key, expected_version, value);
202        self.execute(cmd).await?;
203        Ok(())
204    }
205}
206
207/// A client for interacting with a VedDB server
208#[derive(Clone, Debug)]
209pub struct Client {
210    /// The connection pool
211    pool: ConnectionPool,
212}
213
214impl Client {
215    /// Create a new client connected to the specified address
216    pub async fn connect(addr: impl Into<SocketAddr>) -> Result<Self> {
217        let pool = ConnectionPool::new(addr, 1).await?;
218        Ok(Self { pool })
219    }
220
221    /// Create a new client with a connection pool of the specified size
222    pub async fn with_pool_size(addr: impl Into<SocketAddr>, pool_size: usize) -> Result<Self> {
223        let pool = ConnectionPool::new(addr, pool_size).await?;
224        Ok(Self { pool })
225    }
226
227    /// Ping the server
228    pub async fn ping(&self) -> Result<()> {
229        self.pool.get().await?.ping().await
230    }
231
232    /// Set a key-value pair
233    pub async fn set<K, V>(&self, key: K, value: V) -> Result<()>
234    where
235        K: Into<Bytes>,
236        V: Into<Bytes>,
237    {
238        self.pool.get().await?.set(key, value).await
239    }
240
241    /// Get a value by key
242    pub async fn get<K>(&self, key: K) -> Result<Bytes>
243    where
244        K: Into<Bytes>,
245    {
246        self.pool.get().await?.get(key).await
247    }
248
249    /// Delete a key
250    pub async fn delete<K>(&self, key: K) -> Result<()>
251    where
252        K: Into<Bytes>,
253    {
254        self.pool.get().await?.delete(key).await
255    }
256
257    /// Compare and swap a value
258    pub async fn cas<K, V>(&self, key: K, expected_version: u64, value: V) -> Result<()>
259    where
260        K: Into<Bytes>,
261        V: Into<Bytes>,
262    {
263        self.pool
264            .get()
265            .await?
266            .cas(key, expected_version, value)
267            .await
268    }
269
270    /// List all keys (uses Fetch opcode 0x09)
271    pub async fn list_keys(&self) -> Result<Vec<String>> {
272        let conn = self.pool.get().await?;
273        let cmd = Command::fetch(conn.next_seq(), Bytes::new());
274        let response = conn.execute(cmd).await?;
275        
276        if !response.is_ok() {
277            return Err(Error::Protocol(format!("List keys failed: {:?}", response.status())));
278        }
279        
280        // Parse newline-separated keys
281        let keys_str = String::from_utf8_lossy(&response.payload);
282        let keys: Vec<String> = keys_str
283            .lines()
284            .filter(|s| !s.is_empty())
285            .map(|s| s.to_string())
286            .collect();
287        
288        Ok(keys)
289    }
290}
291
292/// A connection pool for managing multiple connections to a VedDB server
293#[derive(Debug, Clone)]
294pub struct ConnectionPool {
295    /// The server address
296    addr: SocketAddr,
297    /// The connection pool receiver
298    pool: async_channel::Receiver<Connection>,
299    /// The connection pool sender
300    pool_sender: async_channel::Sender<Connection>,
301    /// The number of connections in the pool
302    size: usize,
303}
304
305impl ConnectionPool {
306    /// Create a new connection pool
307    pub async fn new(addr: impl Into<SocketAddr>, size: usize) -> Result<Self> {
308        let addr = addr.into();
309        let (tx, rx) = async_channel::bounded(size);
310
311        // Initialize connections
312        for _ in 0..size {
313            let conn = Connection::connect(addr).await?;
314            tx.send(conn)
315                .await
316                .map_err(|e| Error::Connection(e.to_string()))?;
317        }
318
319        Ok(Self {
320            addr,
321            pool: rx,
322            pool_sender: tx,
323            size,
324        })
325    }
326
327    /// Get a connection from the pool
328    pub async fn get(&self) -> Result<ConnectionGuard> {
329        let conn = self
330            .pool
331            .recv()
332            .await
333            .map_err(|e| Error::Connection(e.to_string()))?;
334        Ok(ConnectionGuard {
335            conn: Some(conn),
336            pool: self.pool_sender.clone(),
337        })
338    }
339
340    /// Get the number of connections in the pool
341    pub fn size(&self) -> usize {
342        self.size
343    }
344}
345
346/// A guard that returns a connection to the pool when dropped
347pub struct ConnectionGuard {
348    /// The connection
349    conn: Option<Connection>,
350    /// The connection pool
351    pool: async_channel::Sender<Connection>,
352}
353
354impl ConnectionGuard {
355    /// Get a reference to the underlying connection
356    pub fn connection(&self) -> &Connection {
357        self.conn.as_ref().unwrap()
358    }
359
360    /// Get a mutable reference to the underlying connection
361    pub fn connection_mut(&mut self) -> &mut Connection {
362        self.conn.as_mut().unwrap()
363    }
364}
365
366impl Drop for ConnectionGuard {
367    fn drop(&mut self) {
368        if let Some(conn) = self.conn.take() {
369            let pool = self.pool.clone();
370            tokio::spawn(async move {
371                if let Err(e) = pool.send(conn).await {
372                    error!("Failed to return connection to pool: {}", e);
373                }
374            });
375        }
376    }
377}
378
379impl std::ops::Deref for ConnectionGuard {
380    type Target = Connection;
381
382    fn deref(&self) -> &Self::Target {
383        self.connection()
384    }
385}
386
387impl std::ops::DerefMut for ConnectionGuard {
388    fn deref_mut(&mut self) -> &mut Self::Target {
389        self.connection_mut()
390    }
391}
392
393/// A builder for configuring and creating a client
394#[derive(Debug, Clone)]
395pub struct ClientBuilder {
396    /// The server address
397    addr: SocketAddr,
398    /// The connection pool size
399    pool_size: usize,
400    /// The connection timeout
401    connect_timeout: Duration,
402    /// The request timeout
403    request_timeout: Duration,
404}
405
406impl Default for ClientBuilder {
407    fn default() -> Self {
408        Self {
409            addr: ([127, 0, 0, 1], 50051).into(),
410            pool_size: 10,
411            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
412            request_timeout: DEFAULT_REQUEST_TIMEOUT,
413        }
414    }
415}
416
417impl ClientBuilder {
418    /// Create a new builder with default settings
419    pub fn new() -> Self {
420        Self::default()
421    }
422
423    /// Set the server address
424    pub fn addr(mut self, addr: impl Into<SocketAddr>) -> Self {
425        self.addr = addr.into();
426        self
427    }
428
429    /// Set the connection pool size
430    pub fn pool_size(mut self, size: usize) -> Self {
431        self.pool_size = size;
432        self
433    }
434
435    /// Set the connection timeout
436    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
437        self.connect_timeout = timeout;
438        self
439    }
440
441    /// Set the request timeout
442    pub fn request_timeout(mut self, timeout: Duration) -> Self {
443        self.request_timeout = timeout;
444        self
445    }
446
447    /// Build and connect the client
448    pub async fn connect(self) -> Result<Client> {
449        let pool = ConnectionPool::new(self.addr, self.pool_size).await?;
450        Ok(Client { pool })
451    }
452}