1use std::net::SocketAddr;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Duration;
6
7use bytes::Bytes;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::Mutex;
11use tokio::time::timeout;
12use tracing::{debug, error, info, trace};
13
14use crate::types::{Command, Response};
15use crate::{Error, Result};
16
/// Default time allowed for establishing the TCP connection to the server.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default bound applied to each individual write, flush, or read of a request.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Largest response payload (16 MiB) accepted before the frame is rejected.
const MAX_FRAME_SIZE: usize = 16 << 20;
23
/// A single TCP connection to a VedDB server.
///
/// The socket sits behind an async mutex, so request/response round trips on one
/// connection happen one at a time.
#[derive(Debug)]
pub struct Connection {
    /// The socket; locked for the duration of each request.
    stream: Mutex<TcpStream>,
    /// Address of the server this connection was opened to.
    addr: SocketAddr,
    /// Source of request sequence numbers; responses are checked to carry the same number.
    next_seq: AtomicU32,
    /// Connect timeout this connection was created with.
    connect_timeout: Duration,
    /// Timeout applied to each individual write, flush, and read of a request.
    request_timeout: Duration,
}
38
39impl Connection {
40 pub async fn connect(addr: impl Into<SocketAddr>) -> Result<Self> {
42 Self::connect_with_timeout(addr, DEFAULT_CONNECT_TIMEOUT, DEFAULT_REQUEST_TIMEOUT).await
43 }
44
45 pub async fn connect_with_timeout(
47 addr: impl Into<SocketAddr>,
48 connect_timeout: Duration,
49 request_timeout: Duration,
50 ) -> Result<Self> {
51 let addr = addr.into();
52 info!("Connecting to VedDB server at {}", addr);
53
54 let stream = timeout(connect_timeout, TcpStream::connect(&addr))
55 .await
56 .map_err(Error::Timeout)??;
57
58 info!("Connected to VedDB server at {}", addr);
59
60 Ok(Self {
61 stream: Mutex::new(stream),
62 addr,
63 next_seq: AtomicU32::new(1),
64 connect_timeout,
65 request_timeout,
66 })
67 }
68
69 fn next_seq(&self) -> u32 {
71 self.next_seq.fetch_add(1, Ordering::SeqCst)
72 }
73
74 pub async fn execute(&self, cmd: Command) -> Result<Response> {
76 let seq = cmd.header.seq;
77 debug!("Executing command: {:?} (seq={})", cmd.header.opcode, seq);
78
79 let mut stream = self.stream.lock().await;
80
81 let cmd_bytes = cmd.to_bytes();
83 debug!("Sending command: {} bytes", cmd_bytes.len());
84 debug!("Command header: {:?}", &cmd_bytes[..std::cmp::min(24, cmd_bytes.len())]);
85
86 timeout(self.request_timeout, stream.write_all(&cmd_bytes))
87 .await
88 .map_err(Error::Timeout)??;
89
90 debug!("Command sent, flushing...");
91 timeout(self.request_timeout, stream.flush())
92 .await
93 .map_err(Error::Timeout)??;
94 debug!("Command flushed");
95
96 debug!("Reading response header (20 bytes)...");
98 let mut header_buf = [0u8; 20];
99 timeout(self.request_timeout, stream.read_exact(&mut header_buf))
100 .await
101 .map_err(Error::Timeout)??;
102 debug!("Response header received: {:?}", &header_buf[..8]);
103
104 let payload_len =
106 u32::from_le_bytes([header_buf[8], header_buf[9], header_buf[10], header_buf[11]]);
107
108 if payload_len as usize > MAX_FRAME_SIZE {
109 return Err(Error::Protocol(format!(
110 "Response too large: {} bytes (max: {})",
111 payload_len, MAX_FRAME_SIZE
112 )));
113 }
114
115 let mut payload = vec![0u8; payload_len as usize];
117 if payload_len > 0 {
118 timeout(self.request_timeout, stream.read_exact(&mut payload))
119 .await
120 .map_err(Error::Timeout)??;
121 }
122
123 let mut response_bytes = Vec::with_capacity(20 + payload_len as usize);
125 response_bytes.extend_from_slice(&header_buf);
126 response_bytes.extend_from_slice(&payload);
127
128 let response = Response::from_bytes(&response_bytes)
129 .map_err(|e| Error::Protocol(format!("Invalid response: {}", e)))?;
130
131 if response.header.seq != seq {
133 return Err(Error::Protocol(format!(
134 "Sequence number mismatch: expected {}, got {}",
135 seq, response.header.seq
136 )));
137 }
138
139 if !response.is_ok() {
141 let status = response.status();
142 let error_msg = String::from_utf8_lossy(&response.payload).into_owned();
143 return Err(Error::Server(format!(
144 "Server error: {:?}: {}",
145 status, error_msg
146 )));
147 }
148
149 Ok(response)
150 }
151
152 pub async fn ping(&self) -> Result<()> {
154 let seq = self.next_seq();
155 let cmd = Command::ping(seq);
156 self.execute(cmd).await?;
157 Ok(())
158 }
159
160 pub async fn set<K, V>(&self, key: K, value: V) -> Result<()>
162 where
163 K: Into<Bytes>,
164 V: Into<Bytes>,
165 {
166 let seq = self.next_seq();
167 let cmd = Command::set(seq, key, value);
168 self.execute(cmd).await?;
169 Ok(())
170 }
171
172 pub async fn get<K>(&self, key: K) -> Result<Bytes>
174 where
175 K: Into<Bytes>,
176 {
177 let seq = self.next_seq();
178 let cmd = Command::get(seq, key);
179 let response = self.execute(cmd).await?;
180 Ok(response.payload)
181 }
182
183 pub async fn delete<K>(&self, key: K) -> Result<()>
185 where
186 K: Into<Bytes>,
187 {
188 let seq = self.next_seq();
189 let cmd = Command::delete(seq, key);
190 self.execute(cmd).await?;
191 Ok(())
192 }
193
194 pub async fn cas<K, V>(&self, key: K, expected_version: u64, value: V) -> Result<()>
196 where
197 K: Into<Bytes>,
198 V: Into<Bytes>,
199 {
200 let seq = self.next_seq();
201 let cmd = Command::cas(seq, key, expected_version, value);
202 self.execute(cmd).await?;
203 Ok(())
204 }
205}
206
/// High-level VedDB client.
///
/// Each operation checks a connection out of the pool for its duration. Cloning a
/// `Client` clones the pool handle, so clones draw from the same connections.
#[derive(Clone, Debug)]
pub struct Client {
    /// Pool the client's operations borrow connections from.
    pool: ConnectionPool,
}
213
214impl Client {
215 pub async fn connect(addr: impl Into<SocketAddr>) -> Result<Self> {
217 let pool = ConnectionPool::new(addr, 1).await?;
218 Ok(Self { pool })
219 }
220
221 pub async fn with_pool_size(addr: impl Into<SocketAddr>, pool_size: usize) -> Result<Self> {
223 let pool = ConnectionPool::new(addr, pool_size).await?;
224 Ok(Self { pool })
225 }
226
227 pub async fn ping(&self) -> Result<()> {
229 self.pool.get().await?.ping().await
230 }
231
232 pub async fn set<K, V>(&self, key: K, value: V) -> Result<()>
234 where
235 K: Into<Bytes>,
236 V: Into<Bytes>,
237 {
238 self.pool.get().await?.set(key, value).await
239 }
240
241 pub async fn get<K>(&self, key: K) -> Result<Bytes>
243 where
244 K: Into<Bytes>,
245 {
246 self.pool.get().await?.get(key).await
247 }
248
249 pub async fn delete<K>(&self, key: K) -> Result<()>
251 where
252 K: Into<Bytes>,
253 {
254 self.pool.get().await?.delete(key).await
255 }
256
257 pub async fn cas<K, V>(&self, key: K, expected_version: u64, value: V) -> Result<()>
259 where
260 K: Into<Bytes>,
261 V: Into<Bytes>,
262 {
263 self.pool
264 .get()
265 .await?
266 .cas(key, expected_version, value)
267 .await
268 }
269
270 pub async fn list_keys(&self) -> Result<Vec<String>> {
272 let conn = self.pool.get().await?;
273 let cmd = Command::fetch(conn.next_seq(), Bytes::new());
274 let response = conn.execute(cmd).await?;
275
276 if !response.is_ok() {
277 return Err(Error::Protocol(format!("List keys failed: {:?}", response.status())));
278 }
279
280 let keys_str = String::from_utf8_lossy(&response.payload);
282 let keys: Vec<String> = keys_str
283 .lines()
284 .filter(|s| !s.is_empty())
285 .map(|s| s.to_string())
286 .collect();
287
288 Ok(keys)
289 }
290}
291
/// Pool of [`Connection`]s opened up front and shared through a bounded async channel.
///
/// Idle connections sit in the channel; checking one out receives it, and a
/// [`ConnectionGuard`] sends it back. Clones share the same channel.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Server address the pooled connections were opened to.
    addr: SocketAddr,
    /// Receiving half of the channel; idle connections are taken from here.
    pool: async_channel::Receiver<Connection>,
    /// Sending half of the channel, cloned into each guard so it can return its connection.
    pool_sender: async_channel::Sender<Connection>,
    /// Number of connections the pool was created with.
    size: usize,
}
304
305impl ConnectionPool {
306 pub async fn new(addr: impl Into<SocketAddr>, size: usize) -> Result<Self> {
308 let addr = addr.into();
309 let (tx, rx) = async_channel::bounded(size);
310
311 for _ in 0..size {
313 let conn = Connection::connect(addr).await?;
314 tx.send(conn)
315 .await
316 .map_err(|e| Error::Connection(e.to_string()))?;
317 }
318
319 Ok(Self {
320 addr,
321 pool: rx,
322 pool_sender: tx,
323 size,
324 })
325 }
326
327 pub async fn get(&self) -> Result<ConnectionGuard> {
329 let conn = self
330 .pool
331 .recv()
332 .await
333 .map_err(|e| Error::Connection(e.to_string()))?;
334 Ok(ConnectionGuard {
335 conn: Some(conn),
336 pool: self.pool_sender.clone(),
337 })
338 }
339
340 pub fn size(&self) -> usize {
342 self.size
343 }
344}
345
346pub struct ConnectionGuard {
348 conn: Option<Connection>,
350 pool: async_channel::Sender<Connection>,
352}
353
354impl ConnectionGuard {
355 pub fn connection(&self) -> &Connection {
357 self.conn.as_ref().unwrap()
358 }
359
360 pub fn connection_mut(&mut self) -> &mut Connection {
362 self.conn.as_mut().unwrap()
363 }
364}
365
366impl Drop for ConnectionGuard {
367 fn drop(&mut self) {
368 if let Some(conn) = self.conn.take() {
369 let pool = self.pool.clone();
370 tokio::spawn(async move {
371 if let Err(e) = pool.send(conn).await {
372 error!("Failed to return connection to pool: {}", e);
373 }
374 });
375 }
376 }
377}
378
379impl std::ops::Deref for ConnectionGuard {
380 type Target = Connection;
381
382 fn deref(&self) -> &Self::Target {
383 self.connection()
384 }
385}
386
387impl std::ops::DerefMut for ConnectionGuard {
388 fn deref_mut(&mut self) -> &mut Self::Target {
389 self.connection_mut()
390 }
391}
392
/// Builder for [`Client`]; starting values come from [`ClientBuilder::default`].
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Server address to connect to.
    addr: SocketAddr,
    /// Number of pooled connections to open.
    pool_size: usize,
    /// Connect timeout configured via [`ClientBuilder::connect_timeout`].
    connect_timeout: Duration,
    /// Request timeout configured via [`ClientBuilder::request_timeout`].
    request_timeout: Duration,
}
405
406impl Default for ClientBuilder {
407 fn default() -> Self {
408 Self {
409 addr: ([127, 0, 0, 1], 50051).into(),
410 pool_size: 10,
411 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
412 request_timeout: DEFAULT_REQUEST_TIMEOUT,
413 }
414 }
415}
416
417impl ClientBuilder {
418 pub fn new() -> Self {
420 Self::default()
421 }
422
423 pub fn addr(mut self, addr: impl Into<SocketAddr>) -> Self {
425 self.addr = addr.into();
426 self
427 }
428
429 pub fn pool_size(mut self, size: usize) -> Self {
431 self.pool_size = size;
432 self
433 }
434
435 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
437 self.connect_timeout = timeout;
438 self
439 }
440
441 pub fn request_timeout(mut self, timeout: Duration) -> Self {
443 self.request_timeout = timeout;
444 self
445 }
446
447 pub async fn connect(self) -> Result<Client> {
449 let pool = ConnectionPool::new(self.addr, self.pool_size).await?;
450 Ok(Client { pool })
451 }
452}