1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
// Compile-time logging shims: with the `tracing` feature enabled these
// forward to the corresponding `tracing` macro; without it they expand to
// nothing, so call sites cost nothing in non-traced builds.
#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => { tracing::trace!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => {};
}
36
// Debug-level shim; no-op unless the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
macro_rules! async_node_debug {
    ($($arg:tt)*) => { tracing::debug!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_debug {
    ($($arg:tt)*) => {};
}
45
// Warn-level shim; no-op unless the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
macro_rules! async_node_warn {
    ($($arg:tt)*) => { tracing::warn!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_warn {
    ($($arg:tt)*) => {};
}
54
// Error-level shim; no-op unless the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
macro_rules! async_node_error {
    ($($arg:tt)*) => { tracing::error!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_error {
    ($($arg:tt)*) => {};
}
63
/// Upper bound on a single inbound frame; anything larger is rejected
/// before allocation to protect against corrupt length prefixes.
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024;

/// Capacity of the channel feeding the dedicated writer task.
const WRITE_BUFFER_SIZE: usize = 1024;

/// Queued writes are coalesced until the batch reaches this many bytes.
const BATCH_WRITE_THRESHOLD: usize = 8192;

/// Per-request deadline enforced by the timeout-sweeper task.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Intended keepalive cadence; not referenced yet, hence the allow.
#[allow(dead_code)]
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60);

/// Commands understood by the writer task.
#[allow(dead_code)]
enum WriteCommand {
    /// Raw serialized bytes to enqueue for the socket.
    Data(Vec<u8>),
    /// Request that any batched bytes be pushed out immediately.
    Flush,
}
78
/// Book-keeping for one in-flight invocation. Stored in `AsyncNode::requests`
/// keyed by `handle` until the matching response arrives or the request is
/// reaped by the timeout sweeper.
#[allow(dead_code)]
struct AsyncNetworkRequest {
    // Client-assigned invocation handle echoed back by the server.
    handle: i64,
    // Set to true by call_sp; marks the request as expecting a result table.
    query: bool,
    // Set to true by call_sp; further semantics not visible in this file.
    sync: bool,
    // Serialized invocation size, taken from the procedure invocation's slen.
    num_bytes: i32,
    // Channel on which the parsed response table is delivered to the caller.
    channel: mpsc::Sender<VoltTable>,
    // Creation time; drives age reporting in Debug and the timeout sweep.
    created_at: Instant,
}
89
90impl Debug for AsyncNetworkRequest {
91 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("AsyncNetworkRequest")
93 .field("handle", &self.handle)
94 .field("query", &self.query)
95 .field("age_ms", &self.created_at.elapsed().as_millis())
96 .finish()
97 }
98}
99
/// Asynchronous connection to a single VoltDB node.
///
/// Owns three background tokio tasks (writer, reader, timeout sweeper)
/// that are signalled via `stop` and joined in `shutdown`.
pub struct AsyncNode {
    // Channel feeding the dedicated writer task.
    write_tx: mpsc::Sender<WriteCommand>,
    // Connection metadata returned by the auth handshake.
    info: ConnInfo,
    // In-flight requests keyed by invocation handle.
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    // Watch channel used to tell the background tasks to shut down.
    stop: Arc<watch::Sender<bool>>,
    // Monotonic source of invocation handles.
    counter: Arc<AtomicI64>,
    // Number of requests awaiting a response; kept in step with `requests`.
    pending_requests: Arc<AtomicUsize>,
    // Join handles for the spawned background tasks.
    task_handles: std::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>,
}
117
118impl Debug for AsyncNode {
119 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("AsyncNode")
121 .field(
122 "pending_requests",
123 &self.pending_requests.load(Ordering::Relaxed),
124 )
125 .field("total_requests", &self.requests.len())
126 .finish()
127 }
128}
129
130impl Drop for AsyncNode {
131 fn drop(&mut self) {
132 let _ = self.stop.send(true);
134 }
135}
136
137impl AsyncNode {
138 pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
140 let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);
141
142 let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;
144
145 let mut stream = TcpStream::connect(&addr).await?;
147
148 stream.set_nodelay(true)?; stream.write_all(&auth_msg).await?;
156 stream.flush().await?;
157
158 let mut len_buf = [0u8; 4];
160 stream.read_exact(&mut len_buf).await?;
161 let read = BigEndian::read_u32(&len_buf) as usize;
162
163 let mut all = vec![0; read];
164 stream.read_exact(&mut all).await?;
165
166 let info = parse_auth_response(&all)?;
168
169 let (read_half, write_half) = tokio::io::split(stream);
171
172 let requests = Arc::new(DashMap::new());
174 let (stop_tx, stop_rx) = watch::channel(false);
175 let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);
176
177 let node = AsyncNode {
178 stop: Arc::new(stop_tx),
179 write_tx,
180 info,
181 requests: requests.clone(),
182 counter: Arc::new(AtomicI64::new(1)),
183 pending_requests: Arc::new(AtomicUsize::new(0)),
184 task_handles: std::sync::Mutex::new(Vec::with_capacity(3)),
185 };
186
187 node.spawn_writer(write_half, write_rx, stop_rx.clone());
189 node.spawn_reader(read_half, stop_rx.clone());
190 node.spawn_timeout_checker(stop_rx);
191
192 Ok(node)
193 }
194
195 #[inline]
197 pub fn get_sequence(&self) -> i64 {
198 self.counter.fetch_add(1, Ordering::Relaxed)
199 }
200
201 #[inline]
203 pub fn pending_count(&self) -> usize {
204 self.pending_requests.load(Ordering::Relaxed)
205 }
206
207 pub fn conn_info(&self) -> &ConnInfo {
209 &self.info
210 }
211
212 pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
214 self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
215 .await
216 }
217
218 pub async fn call_sp(
222 &self,
223 query: &str,
224 param: Vec<&dyn Value>,
225 ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
226 let req = self.get_sequence();
227 let mut proc = new_procedure_invocation(req, false, ¶m, query);
228
229 let (tx, rx) = mpsc::channel(1);
231
232 let seq = AsyncNetworkRequest {
233 query: true,
234 handle: req,
235 num_bytes: proc.slen,
236 sync: true,
237 channel: tx,
238 created_at: Instant::now(),
239 };
240
241 self.requests.insert(req, seq);
243 self.pending_requests.fetch_add(1, Ordering::Relaxed);
244
245 let bs = proc.bytes();
247 self.write_tx
248 .send(WriteCommand::Data(bs))
249 .await
250 .map_err(|_| VoltError::connection_closed())?;
251
252 Ok(rx)
253 }
254
255 pub async fn call_sp_with_timeout(
258 &self,
259 query: &str,
260 param: Vec<&dyn Value>,
261 timeout_duration: Duration,
262 ) -> Result<VoltTable, VoltError> {
263 let mut rx = self.call_sp(query, param).await?;
264 async_block_for_result_with_timeout(&mut rx, timeout_duration).await
265 }
266
267 pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
269 self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
270 }
271
272 pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
274 let mut zero_vec: Vec<&dyn Value> = Vec::new();
275 zero_vec.push(&sql);
276 self.call_sp("@AdHoc", zero_vec).await
277 }
278
279 pub async fn ping(&self) -> Result<(), VoltError> {
281 let zero_vec: Vec<&dyn Value> = Vec::new();
282 let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
283 let bs = proc.bytes();
284
285 self.write_tx
286 .send(WriteCommand::Data(bs))
287 .await
288 .map_err(|_| VoltError::connection_closed())?;
289
290 Ok(())
291 }
292
293 pub async fn shutdown(&self) -> Result<(), VoltError> {
296 let _ = self.stop.send(true);
297 let handles: Vec<_> = self.task_handles.lock().unwrap().drain(..).collect();
298 for handle in handles {
299 let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
300 }
301 Ok(())
302 }
303
304 fn spawn_writer(
306 &self,
307 mut write_half: WriteHalf<TcpStream>,
308 mut write_rx: mpsc::Receiver<WriteCommand>,
309 mut stop_rx: watch::Receiver<bool>,
310 ) {
311 let handle = tokio::spawn(async move {
312 let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);
313
314 loop {
315 tokio::select! {
316 _ = stop_rx.changed() => {
317 if *stop_rx.borrow() {
318 break;
319 }
320 }
321 cmd = write_rx.recv() => {
322 match cmd {
323 Some(WriteCommand::Data(bytes)) => {
324 batch_buffer.extend_from_slice(&bytes);
325
326 while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
328 match write_rx.try_recv() {
329 Ok(WriteCommand::Data(more_bytes)) => {
330 batch_buffer.extend_from_slice(&more_bytes);
331 }
332 Ok(WriteCommand::Flush) => break,
333 Err(_) => break,
334 }
335 }
336
337 if let Err(_e) = write_half.write_all(&batch_buffer).await {
339 async_node_error!(error = %_e, "write error");
340 break;
341 }
342 batch_buffer.clear();
343 }
344 Some(WriteCommand::Flush) => {
345 if !batch_buffer.is_empty() {
346 if let Err(_e) = write_half.write_all(&batch_buffer).await {
347 async_node_error!(error = %_e, "flush error");
348 break;
349 }
350 batch_buffer.clear();
351 }
352 let _ = write_half.flush().await;
353 }
354 None => break,
355 }
356 }
357 }
358 }
359
360 async_node_debug!("writer task terminated");
361 });
362 self.task_handles.lock().unwrap().push(handle);
363 }
364
365 fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
367 let requests = Arc::clone(&self.requests);
368 let pending_requests = Arc::clone(&self.pending_requests);
369
370 let handle = tokio::spawn(async move {
371 let reason = loop {
372 tokio::select! {
373 _ = stop_rx.changed() => {
374 if *stop_rx.borrow() {
375 break "shutdown requested";
376 }
377 }
378 result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
379 if let Err(_e) = result {
380 if !*stop_rx.borrow() {
381 async_node_error!(error = %_e, "read error");
382 }
383 break "connection error";
384 }
385 }
386 }
387 };
388
389 Self::cleanup_requests(&requests, &pending_requests, reason).await;
391 });
392 self.task_handles.lock().unwrap().push(handle);
393 }
394
395 fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
397 let requests = Arc::clone(&self.requests);
398 let pending_requests = Arc::clone(&self.pending_requests);
399
400 let handle = tokio::spawn(async move {
401 let mut interval = tokio::time::interval(Duration::from_secs(5));
402
403 loop {
404 tokio::select! {
405 _ = stop_rx.changed() => {
406 if *stop_rx.borrow() {
407 break;
408 }
409 }
410 _ = interval.tick() => {
411 let now = Instant::now();
412 let mut expired = Vec::new();
413
414 for entry in requests.iter() {
416 let age = now.duration_since(entry.created_at);
417 if age > DEFAULT_TIMEOUT {
418 expired.push(*entry.key());
419 }
420 }
421
422 for handle in expired {
424 if let Some((_, _req)) = requests.remove(&handle) {
425 pending_requests.fetch_sub(1, Ordering::Relaxed);
426 async_node_warn!(
427 handle = handle,
428 elapsed = ?now.duration_since(_req.created_at),
429 "request timed out"
430 );
431 }
433 }
434 }
435 }
436 }
437 });
438 self.task_handles.lock().unwrap().push(handle);
439 }
440
441 async fn async_job(
443 tcp: &mut ReadHalf<TcpStream>,
444 requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
445 pending_requests: &Arc<AtomicUsize>,
446 ) -> Result<(), VoltError> {
447 let mut len_buf = [0u8; 4];
449 tcp.read_exact(&mut len_buf).await?;
450 let msg_len = BigEndian::read_u32(&len_buf) as usize;
451
452 if msg_len > MAX_MESSAGE_SIZE {
454 return Err(VoltError::MessageTooLarge(msg_len));
455 }
456
457 if msg_len == 0 {
458 return Ok(());
459 }
460
461 let mut buf = BytesMut::with_capacity(msg_len);
463 buf.resize(msg_len, 0);
464 tcp.read_exact(&mut buf).await?;
465
466 let _ = buf.get_u8();
468 let handle = buf.get_i64();
469
470 if handle == PING_HANDLE {
472 return Ok(());
473 }
474
475 if let Some((_, req)) = requests.remove(&handle) {
477 pending_requests.fetch_sub(1, Ordering::Relaxed);
478
479 let frozen_buf = buf.freeze();
481
482 tokio::spawn(async move {
484 match Self::parse_response(frozen_buf, handle) {
485 Ok(table) => {
486 let _ = req.channel.send(table).await;
487 }
488 Err(_e) => {
489 async_node_error!(handle = handle, error = %_e, "parse error");
490 }
492 }
493 });
494 } else {
495 async_node_warn!(handle = handle, "received response for unknown handle");
496 }
497
498 Ok(())
499 }
500
501 fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
503 let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
505 let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
506 let table = new_volt_table(&mut byte_buf, info)?;
507 Ok(table)
508 }
509
510 async fn cleanup_requests(
512 requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
513 pending_requests: &Arc<AtomicUsize>,
514 _reason: &str,
515 ) {
516 let pending_count = requests.len();
517
518 if pending_count > 0 {
519 async_node_warn!(
520 pending_count = pending_count,
521 reason = _reason,
522 "cleaning up pending requests"
523 );
524 }
525
526 requests.clear();
528 pending_requests.store(0, Ordering::Relaxed);
529 }
530}
531
532pub async fn async_block_for_result(
534 rx: &mut mpsc::Receiver<VoltTable>,
535) -> Result<VoltTable, VoltError> {
536 match rx.recv().await {
537 Some(mut table) => match table.has_error() {
538 None => Ok(table),
539 Some(err) => Err(err),
540 },
541 None => Err(VoltError::ConnectionNotAvailable),
542 }
543}
544
545pub async fn async_block_for_result_with_timeout(
547 rx: &mut mpsc::Receiver<VoltTable>,
548 timeout_duration: Duration,
549) -> Result<VoltTable, VoltError> {
550 match timeout(timeout_duration, rx.recv()).await {
551 Ok(Some(mut table)) => match table.has_error() {
552 None => Ok(table),
553 Some(err) => Err(err),
554 },
555 Ok(None) => Err(VoltError::ConnectionNotAvailable),
556 Err(_) => Err(VoltError::Timeout),
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a detached `AsyncNode` (no socket, no background tasks) so the
    /// pure counter logic can be exercised without a server.
    fn detached_node(pending: usize) -> AsyncNode {
        AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(pending)),
            task_handles: std::sync::Mutex::new(Vec::new()),
        }
    }

    #[tokio::test]
    async fn test_sequence_generation() {
        let node = detached_node(0);
        let first = node.get_sequence();
        let second = node.get_sequence();
        assert_eq!(second, first + 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let node = detached_node(5);
        assert_eq!(node.pending_count(), 5);
    }

    #[tokio::test]
    async fn test_async_block_for_result_with_timeout_expires() {
        // Sender is held open but never used, so recv() can only time out.
        let (_tx, mut rx) = mpsc::channel::<VoltTable>(1);
        let result = async_block_for_result_with_timeout(&mut rx, Duration::from_millis(50)).await;
        assert!(matches!(result, Err(VoltError::Timeout)));
    }
}