//! voltdb_client_rust — `async_node.rs`: async (tokio) VoltDB connection node.

1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
22// ============================================================================
23// Logging macros - use tracing if available, otherwise no-op
24// ============================================================================
25
// Each logging macro has two cfg-gated definitions: with the "tracing"
// feature it forwards to the corresponding `tracing` macro; without it,
// the macro expands to nothing, so log call sites compile away entirely.
#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => { tracing::trace!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_debug {
    ($($arg:tt)*) => { tracing::debug!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_debug {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_warn {
    ($($arg:tt)*) => { tracing::warn!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_warn {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_error {
    ($($arg:tt)*) => { tracing::error!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_error {
    ($($arg:tt)*) => {};
}
63
// ---- Configuration constants ----

/// Upper bound on a single server response frame (50 MB); larger frames are rejected.
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024; // 50MB message size limit
/// Capacity of the mpsc queue feeding the writer task.
const WRITE_BUFFER_SIZE: usize = 1024; // Write queue capacity
/// Writer stops coalescing queued messages once the batch reaches this many bytes.
const BATCH_WRITE_THRESHOLD: usize = 8192; // Batch write threshold
/// Age after which the timeout checker reaps an unanswered request.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); // Default timeout
#[allow(dead_code)]
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60); // TCP keepalive interval (currently unused; see commented code in `new`)
71
/// Write command for the writer task.
#[allow(dead_code)]
enum WriteCommand {
    /// Serialized request bytes to append to the current write batch.
    Data(Vec<u8>),
    /// Force any buffered bytes onto the socket and flush it.
    Flush,
}
78
/// Async network request tracking: one entry per in-flight procedure call,
/// keyed by `handle` in `AsyncNode::requests`.
#[allow(dead_code)]
struct AsyncNetworkRequest {
    handle: i64, // Client-assigned sequence number echoed back by the server
    query: bool,
    sync: bool,
    num_bytes: i32, // Serialized request size (taken from the invocation's `slen`)
    channel: mpsc::Sender<VoltTable>, // Response path back to the caller; dropping it wakes the waiter
    created_at: Instant, // Used for timeout detection
}
89
90impl Debug for AsyncNetworkRequest {
91    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("AsyncNetworkRequest")
93            .field("handle", &self.handle)
94            .field("query", &self.query)
95            .field("age_ms", &self.created_at.elapsed().as_millis())
96            .finish()
97    }
98}
99
/// Async VoltDB connection node.
///
/// Owns one TCP connection split across three background tasks
/// (writer, reader, timeout checker); requests are correlated with
/// responses via the `handle` sequence number.
pub struct AsyncNode {
    /// Write command sender channel (feeds the writer task)
    write_tx: mpsc::Sender<WriteCommand>,
    /// Connection info from authentication
    info: ConnInfo,
    /// Request map (using DashMap to reduce lock contention)
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    /// Stop signal for background tasks (watch channel; `true` means stop)
    stop: Arc<watch::Sender<bool>>,
    /// Request sequence number counter (starts at 1; see `get_sequence`)
    counter: Arc<AtomicI64>,
    /// Pending request count (used for load balancing)
    pending_requests: Arc<AtomicUsize>,
    /// Handles for background tasks (writer, reader, timeout checker)
    task_handles: std::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>,
}
117
118impl Debug for AsyncNode {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("AsyncNode")
121            .field(
122                "pending_requests",
123                &self.pending_requests.load(Ordering::Relaxed),
124            )
125            .field("total_requests", &self.requests.len())
126            .finish()
127    }
128}
129
130impl Drop for AsyncNode {
131    fn drop(&mut self) {
132        // Signal background tasks to stop
133        let _ = self.stop.send(true);
134    }
135}
136
137impl AsyncNode {
    /// Create a new async connection to VoltDB server.
    ///
    /// Sequence: connect, send the authentication message, read the
    /// length-prefixed auth response, then split the stream and start the
    /// three background tasks (writer, reader, timeout checker).
    pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
        let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);

        // Build authentication message
        let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;

        // Async connect
        let mut stream = TcpStream::connect(&addr).await?;

        // TCP optimization configuration
        stream.set_nodelay(true)?; // Disable Nagle algorithm to reduce latency
        // if let Err(e) = stream.set_keepalive(Some(KEEPALIVE_INTERVAL)) {
        //     eprintln!("Warning: Failed to set keepalive: {}", e);
        // }

        // Async authentication handshake
        stream.write_all(&auth_msg).await?;
        stream.flush().await?;

        // Read authentication response: 4-byte big-endian length prefix, then payload
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await?;
        let read = BigEndian::read_u32(&len_buf) as usize;

        let mut all = vec![0; read];
        stream.read_exact(&mut all).await?;

        // Parse authentication response
        let info = parse_auth_response(&all)?;

        // Split into independently-owned read and write halves for the two I/O tasks
        let (read_half, write_half) = tokio::io::split(stream);

        // Create channels
        let requests = Arc::new(DashMap::new());
        let (stop_tx, stop_rx) = watch::channel(false);
        let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);

        let node = AsyncNode {
            stop: Arc::new(stop_tx),
            write_tx,
            info,
            requests: requests.clone(),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
            task_handles: std::sync::Mutex::new(Vec::with_capacity(3)),
        };

        // Start background tasks
        node.spawn_writer(write_half, write_rx, stop_rx.clone());
        node.spawn_reader(read_half, stop_rx.clone());
        node.spawn_timeout_checker(stop_rx);

        Ok(node)
    }
194
195    /// Get the next sequence number for request tracking
196    #[inline]
197    pub fn get_sequence(&self) -> i64 {
198        self.counter.fetch_add(1, Ordering::Relaxed)
199    }
200
201    /// Get the current pending request count (used for load balancing)
202    #[inline]
203    pub fn pending_count(&self) -> usize {
204        self.pending_requests.load(Ordering::Relaxed)
205    }
206
207    /// Get connection info from authentication
208    pub fn conn_info(&self) -> &ConnInfo {
209        &self.info
210    }
211
212    /// List all stored procedures available in the database
213    pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
214        self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
215            .await
216    }
217
218    /// Call a stored procedure with parameters.
219    /// Returns a receiver for the response. Use `async_block_for_result` or
220    /// `async_block_for_result_with_timeout` to await the result.
221    pub async fn call_sp(
222        &self,
223        query: &str,
224        param: Vec<&dyn Value>,
225    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
226        let req = self.get_sequence();
227        let mut proc = new_procedure_invocation(req, false, &param, query);
228
229        // Create response channel
230        let (tx, rx) = mpsc::channel(1);
231
232        let seq = AsyncNetworkRequest {
233            query: true,
234            handle: req,
235            num_bytes: proc.slen,
236            sync: true,
237            channel: tx,
238            created_at: Instant::now(),
239        };
240
241        // Insert into request map
242        self.requests.insert(req, seq);
243        self.pending_requests.fetch_add(1, Ordering::Relaxed);
244
245        // Send request data
246        let bs = proc.bytes();
247        self.write_tx
248            .send(WriteCommand::Data(bs))
249            .await
250            .map_err(|_| VoltError::connection_closed())?;
251
252        Ok(rx)
253    }
254
255    /// Call a stored procedure with custom timeout.
256    /// Convenience method that sends the request and awaits the response with a timeout.
257    pub async fn call_sp_with_timeout(
258        &self,
259        query: &str,
260        param: Vec<&dyn Value>,
261        timeout_duration: Duration,
262    ) -> Result<VoltTable, VoltError> {
263        let mut rx = self.call_sp(query, param).await?;
264        async_block_for_result_with_timeout(&mut rx, timeout_duration).await
265    }
266
267    /// Upload a JAR file containing stored procedure classes
268    pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
269        self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
270    }
271
272    /// Execute an ad-hoc SQL query
273    pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
274        let mut zero_vec: Vec<&dyn Value> = Vec::new();
275        zero_vec.push(&sql);
276        self.call_sp("@AdHoc", zero_vec).await
277    }
278
279    /// Send a ping to keep the connection alive
280    pub async fn ping(&self) -> Result<(), VoltError> {
281        let zero_vec: Vec<&dyn Value> = Vec::new();
282        let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
283        let bs = proc.bytes();
284
285        self.write_tx
286            .send(WriteCommand::Data(bs))
287            .await
288            .map_err(|_| VoltError::connection_closed())?;
289
290        Ok(())
291    }
292
293    /// Shutdown the connection gracefully.
294    /// Signals background tasks to stop and waits for them to finish.
295    pub async fn shutdown(&self) -> Result<(), VoltError> {
296        let _ = self.stop.send(true);
297        let handles: Vec<_> = self.task_handles.lock().unwrap().drain(..).collect();
298        for handle in handles {
299            let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
300        }
301        Ok(())
302    }
303
304    /// Spawn the writer task (supports batch write optimization)
305    fn spawn_writer(
306        &self,
307        mut write_half: WriteHalf<TcpStream>,
308        mut write_rx: mpsc::Receiver<WriteCommand>,
309        mut stop_rx: watch::Receiver<bool>,
310    ) {
311        let handle = tokio::spawn(async move {
312            let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);
313
314            loop {
315                tokio::select! {
316                    _ = stop_rx.changed() => {
317                        if *stop_rx.borrow() {
318                            break;
319                        }
320                    }
321                    cmd = write_rx.recv() => {
322                        match cmd {
323                            Some(WriteCommand::Data(bytes)) => {
324                                batch_buffer.extend_from_slice(&bytes);
325
326                                // Try to batch collect more data
327                                while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
328                                    match write_rx.try_recv() {
329                                        Ok(WriteCommand::Data(more_bytes)) => {
330                                            batch_buffer.extend_from_slice(&more_bytes);
331                                        }
332                                        Ok(WriteCommand::Flush) => break,
333                                        Err(_) => break,
334                                    }
335                                }
336
337                                // Batch write
338                                if let Err(_e) = write_half.write_all(&batch_buffer).await {
339                                    async_node_error!(error = %_e, "write error");
340                                    break;
341                                }
342                                batch_buffer.clear();
343                            }
344                            Some(WriteCommand::Flush) => {
345                                if !batch_buffer.is_empty() {
346                                    if let Err(_e) = write_half.write_all(&batch_buffer).await {
347                                        async_node_error!(error = %_e, "flush error");
348                                        break;
349                                    }
350                                    batch_buffer.clear();
351                                }
352                                let _ = write_half.flush().await;
353                            }
354                            None => break,
355                        }
356                    }
357                }
358            }
359
360            async_node_debug!("writer task terminated");
361        });
362        self.task_handles.lock().unwrap().push(handle);
363    }
364
    /// Spawn the reader task for receiving responses.
    ///
    /// Loops reading complete messages via `async_job` until either a stop
    /// signal arrives or the connection errors out; on exit it clears the
    /// request map so every waiting caller is woken by its channel closing.
    fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        let handle = tokio::spawn(async move {
            let reason = loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break "shutdown requested";
                        }
                    }
                    result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
                        if let Err(_e) = result {
                            // Suppress the error log when we are shutting down anyway
                            if !*stop_rx.borrow() {
                                async_node_error!(error = %_e, "read error");
                            }
                            break "connection error";
                        }
                    }
                }
            };

            // Cleanup all pending requests (channel drops notify the callers)
            Self::cleanup_requests(&requests, &pending_requests, reason).await;
        });
        self.task_handles.lock().unwrap().push(handle);
    }
394
    /// Spawn the timeout checker task for cleaning up stale requests.
    ///
    /// Every 5 seconds, scans the request map and removes entries older than
    /// `DEFAULT_TIMEOUT`. Removing an entry drops its response channel, which
    /// wakes the waiting caller with a closed-channel result.
    fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        let handle = tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(5));

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    _ = interval.tick() => {
                        let now = Instant::now();
                        let mut expired = Vec::new();

                        // Find expired requests (collect keys first so we do
                        // not remove entries while iterating the DashMap)
                        for entry in requests.iter() {
                            let age = now.duration_since(entry.created_at);
                            if age > DEFAULT_TIMEOUT {
                                expired.push(*entry.key());
                            }
                        }

                        // Cleanup expired requests
                        for handle in expired {
                            if let Some((_, _req)) = requests.remove(&handle) {
                                pending_requests.fetch_sub(1, Ordering::Relaxed);
                                async_node_warn!(
                                    handle = handle,
                                    elapsed = ?now.duration_since(_req.created_at),
                                    "request timed out"
                                );
                                // Channel drop will notify the caller
                            }
                        }
                    }
                }
            }
        });
        self.task_handles.lock().unwrap().push(handle);
    }
440
    /// Read and process a single response from the server.
    ///
    /// Wire format as read here: a 4-byte big-endian length prefix, then the
    /// message body whose first byte is skipped (presumably a protocol
    /// version marker — confirm against the protocol layer) followed by the
    /// i64 client handle used to route the response to its caller.
    async fn async_job(
        tcp: &mut ReadHalf<TcpStream>,
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
    ) -> Result<(), VoltError> {
        // Read message length
        let mut len_buf = [0u8; 4];
        tcp.read_exact(&mut len_buf).await?;
        let msg_len = BigEndian::read_u32(&len_buf) as usize;

        // Safety check for message size
        if msg_len > MAX_MESSAGE_SIZE {
            return Err(VoltError::MessageTooLarge(msg_len));
        }

        if msg_len == 0 {
            return Ok(());
        }

        // Use BytesMut to reduce memory copying
        let mut buf = BytesMut::with_capacity(msg_len);
        buf.resize(msg_len, 0);
        tcp.read_exact(&mut buf).await?;

        // Parse response header
        let _ = buf.get_u8();
        let handle = buf.get_i64();

        // Ping response - just return
        if handle == PING_HANDLE {
            return Ok(());
        }

        // Route response to waiting caller
        if let Some((_, req)) = requests.remove(&handle) {
            pending_requests.fetch_sub(1, Ordering::Relaxed);

            // Freeze buffer so it can be safely moved across tasks
            let frozen_buf = buf.freeze();

            // Parse in a separate task to avoid blocking the read loop
            tokio::spawn(async move {
                match Self::parse_response(frozen_buf, handle) {
                    Ok(table) => {
                        // Receiver may have given up (timeout); ignore send failure
                        let _ = req.channel.send(table).await;
                    }
                    Err(_e) => {
                        async_node_error!(handle = handle, error = %_e, "parse error");
                        // Channel drop will notify the caller
                    }
                }
            });
        } else {
            // Likely a request already reaped by the timeout checker
            async_node_warn!(handle = handle, "received response for unknown handle");
        }

        Ok(())
    }
500
501    /// Parse response data (executed in a separate task)
502    fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
503        // Convert Bytes to ByteBuffer for parsing
504        let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
505        let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
506        let table = new_volt_table(&mut byte_buf, info)?;
507        Ok(table)
508    }
509
510    /// Cleanup all pending requests on shutdown or error
511    async fn cleanup_requests(
512        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
513        pending_requests: &Arc<AtomicUsize>,
514        _reason: &str,
515    ) {
516        let pending_count = requests.len();
517
518        if pending_count > 0 {
519            async_node_warn!(
520                pending_count = pending_count,
521                reason = _reason,
522                "cleaning up pending requests"
523            );
524        }
525
526        // Clear the map (Drop will notify all waiters)
527        requests.clear();
528        pending_requests.store(0, Ordering::Relaxed);
529    }
530}
531
532/// Async wait for response result
533pub async fn async_block_for_result(
534    rx: &mut mpsc::Receiver<VoltTable>,
535) -> Result<VoltTable, VoltError> {
536    match rx.recv().await {
537        Some(mut table) => match table.has_error() {
538            None => Ok(table),
539            Some(err) => Err(err),
540        },
541        None => Err(VoltError::ConnectionNotAvailable),
542    }
543}
544
545/// Async wait for response result with timeout
546pub async fn async_block_for_result_with_timeout(
547    rx: &mut mpsc::Receiver<VoltTable>,
548    timeout_duration: Duration,
549) -> Result<VoltTable, VoltError> {
550    match timeout(timeout_duration, rx.recv()).await {
551        Ok(Some(mut table)) => match table.has_error() {
552            None => Ok(table),
553            Some(err) => Err(err),
554        },
555        Ok(None) => Err(VoltError::ConnectionNotAvailable),
556        Err(_) => Err(VoltError::Timeout),
557    }
558}
559
// Unit tests
#[cfg(test)]
mod tests {
    use super::*;

    // Nodes under test are built by struct literal (no real socket, no
    // background tasks), which is enough to exercise the pure accessors.
    #[tokio::test]
    async fn test_sequence_generation() {
        let node = AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
            task_handles: std::sync::Mutex::new(Vec::new()),
        };

        // Sequence numbers must be strictly increasing by one
        let seq1 = node.get_sequence();
        let seq2 = node.get_sequence();
        assert_eq!(seq2, seq1 + 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let node = AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(5)),
            task_handles: std::sync::Mutex::new(Vec::new()),
        };
        // pending_count mirrors the atomic counter value
        assert_eq!(node.pending_count(), 5);
    }

    #[tokio::test]
    async fn test_async_block_for_result_with_timeout_expires() {
        // Sender is kept alive but never sends, so only the timeout can fire
        let (_tx, mut rx) = mpsc::channel::<VoltTable>(1);
        let result = async_block_for_result_with_timeout(&mut rx, Duration::from_millis(50)).await;
        assert!(matches!(result, Err(VoltError::Timeout)));
    }
}