// voltdb_client_rust/async_node.rs

1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
/// Configuration constants.
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024; // 50 MB upper bound on a single wire message
const WRITE_BUFFER_SIZE: usize = 1024; // capacity of the write-command queue
const BATCH_WRITE_THRESHOLD: usize = 8192; // byte threshold for coalescing writes into one syscall
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); // default per-request timeout
#[allow(dead_code)] // kept for the (currently commented-out) keepalive setup in `new`
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60); // TCP keepalive interval
29
/// Command sent to the background writer task.
#[allow(dead_code)]
enum WriteCommand {
    /// Raw serialized bytes to enqueue for writing to the socket.
    Data(Vec<u8>),
    /// Write out any batched bytes and flush the underlying stream.
    Flush,
}
36
/// Tracking record for one in-flight request.
#[allow(dead_code)]
struct AsyncNetworkRequest {
    handle: i64,                      // client-assigned sequence number matching the response
    query: bool,                      // true when this is a query-style invocation
    sync: bool,                       // caller is synchronously awaiting the response channel
    num_bytes: i32,                   // serialized length of the invocation
    channel: mpsc::Sender<VoltTable>, // where the parsed response table is delivered
    created_at: Instant,              // used for timeout detection
}
47
48impl Debug for AsyncNetworkRequest {
49    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("AsyncNetworkRequest")
51            .field("handle", &self.handle)
52            .field("query", &self.query)
53            .field("age_ms", &self.created_at.elapsed().as_millis())
54            .finish()
55    }
56}
57
/// Asynchronous VoltDB connection node.
pub struct AsyncNode {
    /// Channel feeding the background writer task.
    write_tx: mpsc::Sender<WriteCommand>,
    /// Connection info obtained from the authentication handshake.
    info: ConnInfo,
    /// In-flight requests keyed by handle (DashMap to reduce lock contention).
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    /// Stop signal broadcast to all background tasks.
    stop: Arc<watch::Sender<bool>>,
    /// Monotonic request-handle counter.
    counter: Arc<AtomicI64>,
    /// Number of in-flight requests (used for load balancing).
    pending_requests: Arc<AtomicUsize>,
}
73
74impl Debug for AsyncNode {
75    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("AsyncNode")
77            .field(
78                "pending_requests",
79                &self.pending_requests.load(Ordering::Relaxed),
80            )
81            .field("total_requests", &self.requests.len())
82            .finish()
83    }
84}
85
86impl AsyncNode {
87    /// 创建新的异步连接
88    pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
89        let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);
90
91        // 构建认证消息
92        let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;
93
94        // 异步连接
95        let mut stream = TcpStream::connect(&addr).await?;
96
97        // TCP 优化配置
98        stream.set_nodelay(true)?; // 禁用 Nagle 算法,降低延迟
99        // if let Err(e) = stream.set_keepalive(Some(KEEPALIVE_INTERVAL)) {
100        //     eprintln!("Warning: Failed to set keepalive: {}", e);
101        // }
102
103        // 异步认证握手
104        stream.write_all(&auth_msg).await?;
105        stream.flush().await?;
106
107        // 读取认证响应
108        let mut len_buf = [0u8; 4];
109        stream.read_exact(&mut len_buf).await?;
110        let read = BigEndian::read_u32(&len_buf) as usize;
111
112        let mut all = vec![0; read];
113        stream.read_exact(&mut all).await?;
114
115        // 解析认证响应
116        let info = parse_auth_response(&all)?;
117
118        // 拆分读写流
119        let (read_half, write_half) = tokio::io::split(stream);
120
121        // 创建通道
122        let requests = Arc::new(DashMap::new());
123        let (stop_tx, stop_rx) = watch::channel(false);
124        let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);
125
126        let node = AsyncNode {
127            stop: Arc::new(stop_tx),
128            write_tx,
129            info,
130            requests: requests.clone(),
131            counter: Arc::new(AtomicI64::new(1)),
132            pending_requests: Arc::new(AtomicUsize::new(0)),
133        };
134
135        // 启动后台任务
136        node.spawn_writer(write_half, write_rx, stop_rx.clone());
137        node.spawn_reader(read_half, stop_rx.clone());
138        node.spawn_timeout_checker(stop_rx);
139
140        Ok(node)
141    }
142
    /// Return the next request handle (monotonically increasing, starts at 1).
    #[inline]
    pub fn get_sequence(&self) -> i64 {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }
148
    /// Current number of in-flight requests (used for load balancing).
    #[inline]
    pub fn pending_count(&self) -> usize {
        self.pending_requests.load(Ordering::Relaxed)
    }
154
    /// Connection info obtained during the authentication handshake.
    pub fn conn_info(&self) -> &ConnInfo {
        &self.info
    }
159
160    /// 列出所有存储过程
161    pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
162        self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
163            .await
164    }
165
166    /// 调用存储过程
167    pub async fn call_sp(
168        &self,
169        query: &str,
170        param: Vec<&dyn Value>,
171    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
172        self.call_sp_with_timeout(query, param, DEFAULT_TIMEOUT)
173            .await
174    }
175
176    /// 带超时的存储过程调用
177    pub async fn call_sp_with_timeout(
178        &self,
179        query: &str,
180        param: Vec<&dyn Value>,
181        _timeout_duration: Duration,
182    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
183        let req = self.get_sequence();
184        let mut proc = new_procedure_invocation(req, false, &param, query);
185
186        // 创建响应通道
187        let (tx, rx) = mpsc::channel(1);
188
189        let seq = AsyncNetworkRequest {
190            query: true,
191            handle: req,
192            num_bytes: proc.slen,
193            sync: true,
194            channel: tx,
195            created_at: Instant::now(),
196        };
197
198        // 插入请求映射
199        self.requests.insert(req, seq);
200        self.pending_requests.fetch_add(1, Ordering::Relaxed);
201
202        // 发送请求数据
203        let bs = proc.bytes();
204        self.write_tx
205            .send(WriteCommand::Data(bs))
206            .await
207            .map_err(|_| VoltError::connection_closed())?;
208
209        Ok(rx)
210    }
211
212    /// 上传 JAR 文件
213    pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
214        self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
215    }
216
217    /// 执行 Ad-Hoc SQL 查询
218    pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
219        let mut zero_vec: Vec<&dyn Value> = Vec::new();
220        zero_vec.push(&sql);
221        self.call_sp("@AdHoc", zero_vec).await
222    }
223
224    /// 发送 Ping 保持连接
225    pub async fn ping(&self) -> Result<(), VoltError> {
226        let zero_vec: Vec<&dyn Value> = Vec::new();
227        let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
228        let bs = proc.bytes();
229
230        self.write_tx
231            .send(WriteCommand::Data(bs))
232            .await
233            .map_err(|_| VoltError::connection_closed())?;
234
235        Ok(())
236    }
237
238    /// 关闭连接
239    pub async fn shutdown(&self) -> Result<(), VoltError> {
240        let _ = self.stop.send(true);
241        Ok(())
242    }
243
244    /// 启动写入任务 (支持批量写入优化)
245    fn spawn_writer(
246        &self,
247        mut write_half: WriteHalf<TcpStream>,
248        mut write_rx: mpsc::Receiver<WriteCommand>,
249        mut stop_rx: watch::Receiver<bool>,
250    ) {
251        tokio::spawn(async move {
252            let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);
253
254            loop {
255                tokio::select! {
256                    _ = stop_rx.changed() => {
257                        if *stop_rx.borrow() {
258                            break;
259                        }
260                    }
261                    cmd = write_rx.recv() => {
262                        match cmd {
263                            Some(WriteCommand::Data(bytes)) => {
264                                batch_buffer.extend_from_slice(&bytes);
265
266                                // 尝试批量收集更多数据
267                                while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
268                                    match write_rx.try_recv() {
269                                        Ok(WriteCommand::Data(more_bytes)) => {
270                                            batch_buffer.extend_from_slice(&more_bytes);
271                                        }
272                                        Ok(WriteCommand::Flush) => break,
273                                        Err(_) => break,
274                                    }
275                                }
276
277                                // 批量写入
278                                if let Err(e) = write_half.write_all(&batch_buffer).await {
279                                    eprintln!("Write error: {}", e);
280                                    break;
281                                }
282                                batch_buffer.clear();
283                            }
284                            Some(WriteCommand::Flush) => {
285                                if !batch_buffer.is_empty() {
286                                    if let Err(e) = write_half.write_all(&batch_buffer).await {
287                                        eprintln!("Flush error: {}", e);
288                                        break;
289                                    }
290                                    batch_buffer.clear();
291                                }
292                                let _ = write_half.flush().await;
293                            }
294                            None => break,
295                        }
296                    }
297                }
298            }
299
300            eprintln!("Writer task terminated");
301        });
302    }
303
304    /// 启动读取任务
305    fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
306        let requests = Arc::clone(&self.requests);
307        let pending_requests = Arc::clone(&self.pending_requests);
308
309        tokio::spawn(async move {
310            let reason = loop {
311                tokio::select! {
312                    _ = stop_rx.changed() => {
313                        if *stop_rx.borrow() {
314                            break "shutdown requested";
315                        }
316                    }
317                    result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
318                        if let Err(e) = result {
319                            if !*stop_rx.borrow() {
320                                eprintln!("Read error: {}", e);
321                            }
322                            break "connection error";
323                        }
324                    }
325                }
326            };
327
328            // 清理所有待处理请求
329            Self::cleanup_requests(&requests, &pending_requests, reason).await;
330        });
331    }
332
333    /// 启动超时检查任务
334    fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
335        let requests = Arc::clone(&self.requests);
336        let pending_requests = Arc::clone(&self.pending_requests);
337
338        tokio::spawn(async move {
339            let mut interval = tokio::time::interval(Duration::from_secs(5));
340
341            loop {
342                tokio::select! {
343                    _ = stop_rx.changed() => {
344                        if *stop_rx.borrow() {
345                            break;
346                        }
347                    }
348                    _ = interval.tick() => {
349                        let now = Instant::now();
350                        let mut expired = Vec::new();
351
352                        // 查找超时请求
353                        for entry in requests.iter() {
354                            let age = now.duration_since(entry.created_at);
355                            if age > DEFAULT_TIMEOUT * 2 {
356                                expired.push(*entry.key());
357                            }
358                        }
359
360                        // 清理超时请求
361                        for handle in expired {
362                            if let Some((_, req)) = requests.remove(&handle) {
363                                pending_requests.fetch_sub(1, Ordering::Relaxed);
364                                eprintln!("Request {} timed out after {:?}", handle,
365                                    now.duration_since(req.created_at));
366                                // channel drop 会通知调用者
367                            }
368                        }
369                    }
370                }
371            }
372        });
373    }
374
375    /// 读取并处理单个响应
376    async fn async_job(
377        tcp: &mut ReadHalf<TcpStream>,
378        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
379        pending_requests: &Arc<AtomicUsize>,
380    ) -> Result<(), VoltError> {
381        // 读取消息长度
382        let mut len_buf = [0u8; 4];
383        tcp.read_exact(&mut len_buf).await?;
384        let msg_len = BigEndian::read_u32(&len_buf) as usize;
385
386        // 安全检查
387        if msg_len > MAX_MESSAGE_SIZE {
388            return Err(VoltError::MessageTooLarge(msg_len));
389        }
390
391        if msg_len == 0 {
392            return Ok(());
393        }
394
395        // 使用 BytesMut 减少内存拷贝
396        let mut buf = BytesMut::with_capacity(msg_len);
397        buf.resize(msg_len, 0);
398        tcp.read_exact(&mut buf).await?;
399
400        // 解析响应头
401        let _ = buf.get_u8();
402        let handle = buf.get_i64();
403
404        // Ping 响应直接返回
405        if handle == PING_HANDLE {
406            return Ok(());
407        }
408
409        // 路由响应到等待的调用者
410        if let Some((_, req)) = requests.remove(&handle) {
411            pending_requests.fetch_sub(1, Ordering::Relaxed);
412
413            // 冻结 buffer 以便安全地跨任务移动
414            let frozen_buf = buf.freeze();
415
416            // 在独立任务中解析,避免阻塞读取循环
417            tokio::spawn(async move {
418                match Self::parse_response(frozen_buf, handle) {
419                    Ok(table) => {
420                        let _ = req.channel.send(table).await;
421                    }
422                    Err(e) => {
423                        eprintln!("Parse error for handle {}: {}", handle, e);
424                        // channel drop 会通知调用者
425                    }
426                }
427            });
428        } else {
429            eprintln!("Received response for unknown handle: {}", handle);
430        }
431
432        Ok(())
433    }
434
435    /// 解析响应数据 (在独立任务中执行)
436    fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
437        // 将 Bytes 转换为 ByteBuffer 进行解析
438        let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
439        let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
440        let table = new_volt_table(&mut byte_buf, info)?;
441        Ok(table)
442    }
443
444    /// 清理所有待处理请求
445    async fn cleanup_requests(
446        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
447        pending_requests: &Arc<AtomicUsize>,
448        reason: &str,
449    ) {
450        let pending_count = requests.len();
451
452        if pending_count > 0 {
453            eprintln!("Cleaning up {} pending requests: {}", pending_count, reason);
454        }
455
456        // 清空映射表 (Drop 会通知所有等待者)
457        requests.clear();
458        pending_requests.store(0, Ordering::Relaxed);
459    }
460}
461
462/// 异步等待响应结果
463pub async fn async_block_for_result(
464    rx: &mut mpsc::Receiver<VoltTable>,
465) -> Result<VoltTable, VoltError> {
466    match rx.recv().await {
467        Some(table) => Ok(table), // 直接返回 table,不调用 .has_error()
468        None => Err(VoltError::ConnectionNotAvailable),
469    }
470}
471
472/// 带超时的异步等待
473pub async fn async_block_for_result_with_timeout(
474    rx: &mut mpsc::Receiver<VoltTable>,
475    timeout_duration: Duration,
476) -> Result<VoltTable, VoltError> {
477    match timeout(timeout_duration, rx.recv()).await {
478        Ok(Some(mut table)) => match table.has_error() {
479            None => Ok(table),
480            Some(err) => Err(err),
481        },
482        Ok(None) => Err(VoltError::ConnectionNotAvailable),
483        Err(_) => Err(VoltError::Timeout),
484    }
485}
486
/// Convenience constructors for the `VoltError` variants used by this
/// module. (The variants themselves are declared in `encode.rs`.)
impl VoltError {
    /// Frame exceeded `MAX_MESSAGE_SIZE`; carries the offending size.
    pub fn message_too_large(size: usize) -> Self {
        VoltError::MessageTooLarge(size)
    }

    /// The writer task / underlying connection has gone away.
    pub fn connection_closed() -> Self {
        VoltError::ConnectionClosed
    }

    /// A request exceeded its allowed wait time.
    pub fn timeout() -> Self {
        VoltError::Timeout
    }
}
501
// Unit tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `AsyncNode` with dummy channels for offline tests
    /// (no socket, no background tasks).
    fn offline_node(pending: usize) -> AsyncNode {
        AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(pending)),
        }
    }

    #[tokio::test]
    async fn test_sequence_generation() {
        let node = offline_node(0);
        let first = node.get_sequence();
        let second = node.get_sequence();
        assert_eq!(second, first + 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        assert_eq!(offline_node(5).pending_count(), 5);
    }
}