//! voltdb_client_rust — async_node.rs
//!
//! Tokio-based asynchronous VoltDB connection node.
1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
// ============================================================================
// Logging macros — forward to `tracing` when that feature is enabled,
// otherwise expand to nothing so call sites have zero runtime cost.
// ============================================================================

#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => { tracing::trace!($($arg)*) };
}
// No-op fallback when the "tracing" feature is disabled.
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_debug {
    ($($arg:tt)*) => { tracing::debug!($($arg)*) };
}
// No-op fallback when the "tracing" feature is disabled.
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_debug {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_warn {
    ($($arg:tt)*) => { tracing::warn!($($arg)*) };
}
// No-op fallback when the "tracing" feature is disabled.
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_warn {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_error {
    ($($arg:tt)*) => { tracing::error!($($arg)*) };
}
// No-op fallback when the "tracing" feature is disabled.
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_error {
    ($($arg:tt)*) => {};
}
63
/// Configuration constants
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024; // 50MB message size limit (rejects oversized frames in async_job)
const WRITE_BUFFER_SIZE: usize = 1024; // Write queue capacity (mpsc channel bound)
const BATCH_WRITE_THRESHOLD: usize = 8192; // Batch write threshold in bytes for the writer task
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); // Default timeout for call_sp
#[allow(dead_code)]
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60); // TCP keepalive interval (currently unused)

/// Command sent over the write channel to the background writer task.
#[allow(dead_code)]
enum WriteCommand {
    /// A fully encoded request frame to be written to the socket.
    Data(Vec<u8>),
    /// Force the writer to drain its batch buffer and flush the socket.
    /// NOTE(review): matched by the writer but never constructed in this
    /// file — confirm whether external callers can send it.
    Flush,
}
78
/// Tracking record for one in-flight request: keyed by `handle` in the
/// node's request map until the matching response arrives or it times out.
#[allow(dead_code)]
struct AsyncNetworkRequest {
    // Client-assigned sequence number echoed back by the server.
    handle: i64,
    // True for query-style invocations (always true in this file).
    query: bool,
    // Whether the caller waits synchronously on the channel.
    sync: bool,
    // Size in bytes of the serialized invocation.
    num_bytes: i32,
    // Response delivery channel; dropping it (without sending) signals
    // failure to the waiting caller.
    channel: mpsc::Sender<VoltTable>,
    created_at: Instant, // Used for timeout detection by the checker task
}

impl Debug for AsyncNetworkRequest {
    // Compact debug output: handle, kind, and age in milliseconds.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncNetworkRequest")
            .field("handle", &self.handle)
            .field("query", &self.query)
            .field("age_ms", &self.created_at.elapsed().as_millis())
            .finish()
    }
}
99
/// Async VoltDB connection node.
///
/// Owns the channels and shared state for three background tasks spawned in
/// [`AsyncNode::new`]: a socket writer, a response reader, and a stale-request
/// timeout checker. Dropping the node signals all of them to stop.
pub struct AsyncNode {
    /// Write command sender channel (consumed by the writer task)
    write_tx: mpsc::Sender<WriteCommand>,
    /// Connection info from authentication
    info: ConnInfo,
    /// Request map (using DashMap to reduce lock contention)
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    /// Stop signal for background tasks
    stop: Arc<watch::Sender<bool>>,
    /// Request sequence number counter
    counter: Arc<AtomicI64>,
    /// Pending request count (used for load balancing)
    pending_requests: Arc<AtomicUsize>,
}
115
impl Debug for AsyncNode {
    // Summarize node load: the atomic pending counter plus the actual size
    // of the request map (the two can transiently disagree).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncNode")
            .field(
                "pending_requests",
                &self.pending_requests.load(Ordering::Relaxed),
            )
            .field("total_requests", &self.requests.len())
            .finish()
    }
}
127
128impl Drop for AsyncNode {
129    fn drop(&mut self) {
130        // Signal background tasks to stop
131        let _ = self.stop.send(true);
132    }
133}
134
impl AsyncNode {
    /// Create a new async connection to a VoltDB server.
    ///
    /// Performs the authentication handshake, splits the stream, and spawns
    /// three background tasks: a batching writer, a response reader, and a
    /// stale-request timeout checker.
    ///
    /// # Errors
    /// Returns `VoltError` if connecting, the handshake I/O, or parsing the
    /// authentication response fails.
    pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
        let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);

        // Build authentication message
        let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;

        // Async connect
        let mut stream = TcpStream::connect(&addr).await?;

        // TCP optimization configuration
        stream.set_nodelay(true)?; // Disable Nagle algorithm to reduce latency
        // NOTE(review): TCP keepalive is not configured here — KEEPALIVE_INTERVAL
        // is currently unused (hence its #[allow(dead_code)]).

        // Async authentication handshake
        stream.write_all(&auth_msg).await?;
        stream.flush().await?;

        // Read authentication response: a 4-byte big-endian length prefix
        // followed by the payload.
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await?;
        let read = BigEndian::read_u32(&len_buf) as usize;

        let mut all = vec![0; read];
        stream.read_exact(&mut all).await?;

        // Parse authentication response
        let info = parse_auth_response(&all)?;

        // Split into read and write halves so each background task owns one
        let (read_half, write_half) = tokio::io::split(stream);

        // Create channels shared with the background tasks
        let requests = Arc::new(DashMap::new());
        let (stop_tx, stop_rx) = watch::channel(false);
        let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);

        let node = AsyncNode {
            stop: Arc::new(stop_tx),
            write_tx,
            info,
            requests: requests.clone(),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
        };

        // Start background tasks
        node.spawn_writer(write_half, write_rx, stop_rx.clone());
        node.spawn_reader(read_half, stop_rx.clone());
        node.spawn_timeout_checker(stop_rx);

        Ok(node)
    }

    /// Get the next sequence number for request tracking.
    /// Relaxed ordering is sufficient: only uniqueness matters, not ordering
    /// relative to other memory operations.
    #[inline]
    pub fn get_sequence(&self) -> i64 {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Get the current pending request count (used for load balancing).
    #[inline]
    pub fn pending_count(&self) -> usize {
        self.pending_requests.load(Ordering::Relaxed)
    }

    /// Get connection info obtained during authentication.
    pub fn conn_info(&self) -> &ConnInfo {
        &self.info
    }

    /// List all stored procedures available in the database
    /// (via the `@SystemCatalog` system procedure).
    pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
            .await
    }

    /// Call a stored procedure with parameters, using the default timeout.
    ///
    /// Returns a receiver that yields the response table; if the channel is
    /// closed without a value, the request failed or timed out.
    pub async fn call_sp(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp_with_timeout(query, param, DEFAULT_TIMEOUT)
            .await
    }

    /// Call a stored procedure with a custom timeout.
    ///
    /// NOTE(review): `_timeout_duration` is currently ignored — per-request
    /// timeouts are not stored, and the background checker expires every
    /// request at a fixed `DEFAULT_TIMEOUT * 2`. Callers needing a real
    /// timeout should use `async_block_for_result_with_timeout` on the
    /// returned receiver.
    pub async fn call_sp_with_timeout(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
        _timeout_duration: Duration,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let req = self.get_sequence();
        let mut proc = new_procedure_invocation(req, false, &param, query);

        // Create response channel (capacity 1: exactly one response expected)
        let (tx, rx) = mpsc::channel(1);

        let seq = AsyncNetworkRequest {
            query: true,
            handle: req,
            num_bytes: proc.slen,
            sync: true,
            channel: tx,
            created_at: Instant::now(),
        };

        // Register before sending so the reader can never see an unknown handle
        self.requests.insert(req, seq);
        self.pending_requests.fetch_add(1, Ordering::Relaxed);

        // Hand the encoded request to the writer task
        let bs = proc.bytes();
        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;
        // NOTE(review): if this send fails, the entry inserted above stays in
        // the map until the timeout checker reaps it.

        Ok(rx)
    }

    /// Upload a JAR file containing stored procedure classes
    /// (via the `@UpdateClasses` system procedure).
    pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
    }

    /// Execute an ad-hoc SQL query via the `@AdHoc` system procedure.
    pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let mut zero_vec: Vec<&dyn Value> = Vec::new();
        zero_vec.push(&sql);
        self.call_sp("@AdHoc", zero_vec).await
    }

    /// Send a ping to keep the connection alive.
    ///
    /// Uses the reserved `PING_HANDLE`, so the reader discards the response
    /// instead of routing it to a caller.
    pub async fn ping(&self) -> Result<(), VoltError> {
        let zero_vec: Vec<&dyn Value> = Vec::new();
        let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
        let bs = proc.bytes();

        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;

        Ok(())
    }

    /// Shutdown the connection gracefully by signaling all background tasks.
    pub async fn shutdown(&self) -> Result<(), VoltError> {
        let _ = self.stop.send(true);
        Ok(())
    }

    /// Spawn the writer task (supports batch write optimization).
    ///
    /// Greedily drains queued `Data` commands up to `BATCH_WRITE_THRESHOLD`
    /// bytes before issuing a single `write_all`, reducing syscalls under load.
    fn spawn_writer(
        &self,
        mut write_half: WriteHalf<TcpStream>,
        mut write_rx: mpsc::Receiver<WriteCommand>,
        mut stop_rx: watch::Receiver<bool>,
    ) {
        tokio::spawn(async move {
            let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            // NOTE(review): any bytes still in batch_buffer or
                            // the channel are dropped unsent on shutdown.
                            break;
                        }
                    }
                    cmd = write_rx.recv() => {
                        match cmd {
                            Some(WriteCommand::Data(bytes)) => {
                                batch_buffer.extend_from_slice(&bytes);

                                // Opportunistically batch any already-queued data
                                // (try_recv never blocks).
                                while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
                                    match write_rx.try_recv() {
                                        Ok(WriteCommand::Data(more_bytes)) => {
                                            batch_buffer.extend_from_slice(&more_bytes);
                                        }
                                        Ok(WriteCommand::Flush) => break,
                                        Err(_) => break,
                                    }
                                }

                                // Batch write
                                if let Err(_e) = write_half.write_all(&batch_buffer).await {
                                    async_node_error!(error = %_e, "write error");
                                    break;
                                }
                                batch_buffer.clear();
                            }
                            Some(WriteCommand::Flush) => {
                                if !batch_buffer.is_empty() {
                                    if let Err(_e) = write_half.write_all(&batch_buffer).await {
                                        async_node_error!(error = %_e, "flush error");
                                        break;
                                    }
                                    batch_buffer.clear();
                                }
                                let _ = write_half.flush().await;
                            }
                            // Channel closed: the node was dropped.
                            None => break,
                        }
                    }
                }
            }

            async_node_debug!("writer task terminated");
        });
    }

    /// Spawn the reader task for receiving responses.
    ///
    /// Loops on `async_job` until a stop signal or a read error, then cleans
    /// up all pending requests (dropping their channels notifies the waiters).
    fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let reason = loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break "shutdown requested";
                        }
                    }
                    // NOTE(review): if the stop branch wins while async_job is
                    // mid-frame, the partial read is abandoned; that is fine on
                    // shutdown but would desync framing if the loop continued.
                    result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
                        if let Err(_e) = result {
                            if !*stop_rx.borrow() {
                                async_node_error!(error = %_e, "read error");
                            }
                            break "connection error";
                        }
                    }
                }
            };

            // Cleanup all pending requests
            Self::cleanup_requests(&requests, &pending_requests, reason).await;
        });
    }

    /// Spawn the timeout checker task for cleaning up stale requests.
    ///
    /// Every 5 seconds, removes requests older than `DEFAULT_TIMEOUT * 2`;
    /// dropping the removed entry's channel wakes the waiting caller with
    /// a closed-channel error.
    fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(5));

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    _ = interval.tick() => {
                        let now = Instant::now();
                        let mut expired = Vec::new();

                        // Collect expired handles first; removing while holding
                        // a DashMap iterator could deadlock on the same shard.
                        for entry in requests.iter() {
                            let age = now.duration_since(entry.created_at);
                            if age > DEFAULT_TIMEOUT * 2 {
                                expired.push(*entry.key());
                            }
                        }

                        // Cleanup expired requests
                        for handle in expired {
                            if let Some((_, _req)) = requests.remove(&handle) {
                                pending_requests.fetch_sub(1, Ordering::Relaxed);
                                async_node_warn!(
                                    handle = handle,
                                    elapsed = ?now.duration_since(_req.created_at),
                                    "request timed out"
                                );
                                // Channel drop will notify the caller
                            }
                        }
                    }
                }
            }
        });
    }

    /// Read and process a single response frame from the server.
    ///
    /// Frame layout (as consumed here): 4-byte big-endian length, then a
    /// 1-byte field (skipped), an i64 client handle, and the response body.
    ///
    /// # Errors
    /// Returns `VoltError` on I/O failure or when the frame exceeds
    /// `MAX_MESSAGE_SIZE`.
    async fn async_job(
        tcp: &mut ReadHalf<TcpStream>,
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
    ) -> Result<(), VoltError> {
        // Read message length
        let mut len_buf = [0u8; 4];
        tcp.read_exact(&mut len_buf).await?;
        let msg_len = BigEndian::read_u32(&len_buf) as usize;

        // Safety check for message size
        if msg_len > MAX_MESSAGE_SIZE {
            return Err(VoltError::MessageTooLarge(msg_len));
        }

        if msg_len == 0 {
            return Ok(());
        }

        // Use BytesMut to reduce memory copying
        let mut buf = BytesMut::with_capacity(msg_len);
        buf.resize(msg_len, 0);
        tcp.read_exact(&mut buf).await?;

        // Parse response header
        let _ = buf.get_u8();
        let handle = buf.get_i64();

        // Ping response - just return
        if handle == PING_HANDLE {
            return Ok(());
        }

        // Route response to waiting caller
        if let Some((_, req)) = requests.remove(&handle) {
            pending_requests.fetch_sub(1, Ordering::Relaxed);

            // Freeze buffer so it can be safely moved across tasks
            let frozen_buf = buf.freeze();

            // Parse in a separate task to avoid blocking the read loop
            tokio::spawn(async move {
                match Self::parse_response(frozen_buf, handle) {
                    Ok(table) => {
                        let _ = req.channel.send(table).await;
                    }
                    Err(_e) => {
                        async_node_error!(handle = handle, error = %_e, "parse error");
                        // Channel drop will notify the caller
                    }
                }
            });
        } else {
            // Late response for an already-expired or unknown request.
            async_node_warn!(handle = handle, "received response for unknown handle");
        }

        Ok(())
    }

    /// Parse response data into a `VoltTable` (executed in a separate task).
    fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
        // Convert Bytes to ByteBuffer for parsing
        let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
        let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
        let table = new_volt_table(&mut byte_buf, info)?;
        Ok(table)
    }

    /// Cleanup all pending requests on shutdown or connection error.
    /// Clearing the map drops every request's sender, which wakes waiters
    /// with a closed channel.
    async fn cleanup_requests(
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
        _reason: &str,
    ) {
        let pending_count = requests.len();

        if pending_count > 0 {
            async_node_warn!(
                pending_count = pending_count,
                reason = _reason,
                "cleaning up pending requests"
            );
        }

        // Clear the map (Drop will notify all waiters)
        requests.clear();
        pending_requests.store(0, Ordering::Relaxed);
    }
}
517
518/// Async wait for response result
519pub async fn async_block_for_result(
520    rx: &mut mpsc::Receiver<VoltTable>,
521) -> Result<VoltTable, VoltError> {
522    match rx.recv().await {
523        Some(mut table) => match table.has_error() {
524            None => Ok(table),
525            Some(err) => Err(err),
526        },
527        None => Err(VoltError::ConnectionNotAvailable),
528    }
529}
530
531/// Async wait for response result with timeout
532pub async fn async_block_for_result_with_timeout(
533    rx: &mut mpsc::Receiver<VoltTable>,
534    timeout_duration: Duration,
535) -> Result<VoltTable, VoltError> {
536    match timeout(timeout_duration, rx.recv()).await {
537        Ok(Some(mut table)) => match table.has_error() {
538            None => Ok(table),
539            Some(err) => Err(err),
540        },
541        Ok(None) => Err(VoltError::ConnectionNotAvailable),
542        Err(_) => Err(VoltError::Timeout),
543    }
544}
545
/// VoltError convenience constructors used by the async code paths.
impl VoltError {
    /// Build a `MessageTooLarge` error for a frame of `size` bytes.
    pub fn message_too_large(size: usize) -> Self {
        VoltError::MessageTooLarge(size)
    }

    /// Build a `ConnectionClosed` error (writer channel gone).
    pub fn connection_closed() -> Self {
        VoltError::ConnectionClosed
    }

    /// Build a `Timeout` error.
    pub fn timeout() -> Self {
        VoltError::Timeout
    }
}
560
// Unit tests
#[cfg(test)]
mod tests {
    use super::*;

    // Sequence numbers must increase monotonically per node.
    #[tokio::test]
    async fn test_sequence_generation() {
        // Construct the node by hand; no live connection is needed to
        // exercise the counter.
        let node = AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
        };

        let seq1 = node.get_sequence();
        let seq2 = node.get_sequence();
        assert_eq!(seq2, seq1 + 1);
    }

    // pending_count must reflect the atomic counter's current value.
    #[tokio::test]
    async fn test_pending_count() {
        let node = AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(5)),
        };
        assert_eq!(node.pending_count(), 5);
    }
}
594}