1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
// Feature-gated logging shims: with the `tracing` feature enabled each macro
// forwards to the matching `tracing` macro; without it the invocation expands
// to nothing, so log statements in this module cost nothing when tracing is
// compiled out.

#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => { tracing::trace!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! async_node_trace {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_debug {
    ($($arg:tt)*) => { tracing::debug!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_debug {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_warn {
    ($($arg:tt)*) => { tracing::warn!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_warn {
    ($($arg:tt)*) => {};
}

#[cfg(feature = "tracing")]
macro_rules! async_node_error {
    ($($arg:tt)*) => { tracing::error!($($arg)*) };
}
#[cfg(not(feature = "tracing"))]
macro_rules! async_node_error {
    ($($arg:tt)*) => {};
}
63
/// Hard upper bound on an incoming message body; larger frames are rejected.
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024;
/// Capacity of the mpsc channel feeding the writer task.
const WRITE_BUFFER_SIZE: usize = 1024;
/// The writer task batches queued payloads up to this many bytes per write.
const BATCH_WRITE_THRESHOLD: usize = 8192;
/// Default per-call timeout used by `call_sp`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
#[allow(dead_code)]
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60);

/// Commands accepted by the background writer task.
#[allow(dead_code)]
enum WriteCommand {
    /// Raw serialized bytes to send on the socket.
    Data(Vec<u8>),
    /// Force any batched bytes out and flush the socket.
    Flush,
}
78
/// Book-keeping for one in-flight procedure invocation. An entry lives in
/// `AsyncNode::requests` keyed by `handle` until the matching response
/// arrives or the timeout checker expires it.
#[allow(dead_code)]
struct AsyncNetworkRequest {
    // Client-assigned invocation handle echoed back by the server.
    handle: i64,
    // Always set to true by `call_sp_with_timeout`; not read elsewhere here.
    query: bool,
    // Always set to true by `call_sp_with_timeout`; not read elsewhere here.
    sync: bool,
    // Serialized size of the invocation (taken from `proc.slen`).
    num_bytes: i32,
    // Capacity-1 channel the reader task delivers the parsed table into.
    channel: mpsc::Sender<VoltTable>,
    // Creation time, used by the timeout checker to expire stale requests.
    created_at: Instant,
}
89
90impl Debug for AsyncNetworkRequest {
91 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("AsyncNetworkRequest")
93 .field("handle", &self.handle)
94 .field("query", &self.query)
95 .field("age_ms", &self.created_at.elapsed().as_millis())
96 .finish()
97 }
98}
99
/// Async connection to a single VoltDB node. Outgoing invocations are
/// serialized and queued to a background writer task; a background reader
/// task matches responses back to callers through the `requests` map.
pub struct AsyncNode {
    // Queue of wire data / flush commands consumed by the writer task.
    write_tx: mpsc::Sender<WriteCommand>,
    // Connection metadata parsed from the auth handshake response.
    info: ConnInfo,
    // In-flight requests keyed by invocation handle.
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    // Stop signal watched by all background tasks; fired on drop/shutdown.
    stop: Arc<watch::Sender<bool>>,
    // Monotonic source of invocation handles (see `get_sequence`).
    counter: Arc<AtomicI64>,
    // Count of requests awaiting responses; kept in step with `requests`.
    pending_requests: Arc<AtomicUsize>,
}
115
116impl Debug for AsyncNode {
117 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("AsyncNode")
119 .field(
120 "pending_requests",
121 &self.pending_requests.load(Ordering::Relaxed),
122 )
123 .field("total_requests", &self.requests.len())
124 .finish()
125 }
126}
127
128impl Drop for AsyncNode {
129 fn drop(&mut self) {
130 let _ = self.stop.send(true);
132 }
133}
134
impl AsyncNode {
    /// Connects to the node described by `opt`, performs the authentication
    /// handshake, then spawns the writer, reader and timeout-checker
    /// background tasks.
    ///
    /// # Errors
    /// Fails if the TCP connect, the auth write/read, or
    /// `parse_auth_response` fails.
    pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
        let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);

        let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;

        let mut stream = TcpStream::connect(&addr).await?;

        // Disable Nagle: the writer task already batches small writes itself.
        stream.set_nodelay(true)?;
        stream.write_all(&auth_msg).await?;
        stream.flush().await?;

        // Auth response: 4-byte big-endian length prefix, then the payload.
        // NOTE(review): no timeout on this read — a stalled server blocks
        // `new` indefinitely; consider wrapping in `tokio::time::timeout`.
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await?;
        let read = BigEndian::read_u32(&len_buf) as usize;

        let mut all = vec![0; read];
        stream.read_exact(&mut all).await?;

        let info = parse_auth_response(&all)?;

        // Split so the reader and writer tasks can own their halves.
        let (read_half, write_half) = tokio::io::split(stream);

        let requests = Arc::new(DashMap::new());
        let (stop_tx, stop_rx) = watch::channel(false);
        let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);

        let node = AsyncNode {
            stop: Arc::new(stop_tx),
            write_tx,
            info,
            requests: requests.clone(),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
        };

        node.spawn_writer(write_half, write_rx, stop_rx.clone());
        node.spawn_reader(read_half, stop_rx.clone());
        node.spawn_timeout_checker(stop_rx);

        Ok(node)
    }

    /// Returns the next invocation handle (post-increment of `counter`).
    #[inline]
    pub fn get_sequence(&self) -> i64 {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Number of requests currently awaiting a response.
    #[inline]
    pub fn pending_count(&self) -> usize {
        self.pending_requests.load(Ordering::Relaxed)
    }

    /// Connection metadata captured during the auth handshake.
    pub fn conn_info(&self) -> &ConnInfo {
        &self.info
    }

    /// Lists stored procedures via the `@SystemCatalog` system procedure.
    pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
            .await
    }

    /// Invokes stored procedure `query` with `param` using `DEFAULT_TIMEOUT`
    /// (see `call_sp_with_timeout` for the delivery contract).
    pub async fn call_sp(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp_with_timeout(query, param, DEFAULT_TIMEOUT)
            .await
    }

    /// Serializes the invocation, registers it in `requests`, and queues the
    /// bytes for the writer task. The returned receiver yields the parsed
    /// response table once the reader task has matched it by handle.
    ///
    /// NOTE(review): `_timeout_duration` is currently ignored — stale
    /// requests are only expired by the background checker once they exceed
    /// `DEFAULT_TIMEOUT * 2`.
    ///
    /// # Errors
    /// Returns `VoltError::ConnectionClosed` if the writer task's channel
    /// has shut down.
    pub async fn call_sp_with_timeout(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
        _timeout_duration: Duration,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let req = self.get_sequence();
        let mut proc = new_procedure_invocation(req, false, &param, query);

        // Capacity 1: exactly one response is expected per invocation.
        let (tx, rx) = mpsc::channel(1);

        let seq = AsyncNetworkRequest {
            query: true,
            handle: req,
            num_bytes: proc.slen,
            sync: true,
            channel: tx,
            created_at: Instant::now(),
        };

        // Register before sending so the reader can never race a response
        // against an unregistered handle.
        self.requests.insert(req, seq);
        self.pending_requests.fetch_add(1, Ordering::Relaxed);

        let bs = proc.bytes();
        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;

        Ok(rx)
    }

    /// Uploads a jar of stored-procedure classes via `@UpdateClasses`.
    pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
    }

    /// Runs an ad-hoc SQL statement via the `@AdHoc` system procedure.
    pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let mut zero_vec: Vec<&dyn Value> = Vec::new();
        zero_vec.push(&sql);
        self.call_sp("@AdHoc", zero_vec).await
    }

    /// Fire-and-forget `@Ping`. No request entry is registered; the reader
    /// task silently discards responses carrying `PING_HANDLE`.
    pub async fn ping(&self) -> Result<(), VoltError> {
        let zero_vec: Vec<&dyn Value> = Vec::new();
        let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
        let bs = proc.bytes();

        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;

        Ok(())
    }

    /// Signals all background tasks to stop. Does not wait for them to exit.
    pub async fn shutdown(&self) -> Result<(), VoltError> {
        let _ = self.stop.send(true);
        Ok(())
    }

    /// Spawns the task that owns the write half: drains `write_rx`,
    /// opportunistically batching already-queued `Data` payloads up to
    /// `BATCH_WRITE_THRESHOLD` bytes into a single `write_all`.
    fn spawn_writer(
        &self,
        mut write_half: WriteHalf<TcpStream>,
        mut write_rx: mpsc::Receiver<WriteCommand>,
        mut stop_rx: watch::Receiver<bool>,
    ) {
        tokio::spawn(async move {
            let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    cmd = write_rx.recv() => {
                        match cmd {
                            Some(WriteCommand::Data(bytes)) => {
                                batch_buffer.extend_from_slice(&bytes);

                                // Pull whatever is already queued without
                                // awaiting, to coalesce small writes.
                                // NOTE(review): a `Flush` drained here only
                                // ends the batching loop; its explicit
                                // socket flush is dropped.
                                while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
                                    match write_rx.try_recv() {
                                        Ok(WriteCommand::Data(more_bytes)) => {
                                            batch_buffer.extend_from_slice(&more_bytes);
                                        }
                                        Ok(WriteCommand::Flush) => break,
                                        Err(_) => break,
                                    }
                                }

                                if let Err(_e) = write_half.write_all(&batch_buffer).await {
                                    async_node_error!(error = %_e, "write error");
                                    break;
                                }
                                batch_buffer.clear();
                            }
                            Some(WriteCommand::Flush) => {
                                if !batch_buffer.is_empty() {
                                    if let Err(_e) = write_half.write_all(&batch_buffer).await {
                                        async_node_error!(error = %_e, "flush error");
                                        break;
                                    }
                                    batch_buffer.clear();
                                }
                                let _ = write_half.flush().await;
                            }
                            // Channel closed: all senders (the node) dropped.
                            None => break,
                        }
                    }
                }
            }

            async_node_debug!("writer task terminated");
        });
    }

    /// Spawns the task that owns the read half: loops reading one framed
    /// response at a time and dispatching it to the waiting caller. On exit
    /// (stop signal or read error) all pending requests are cleaned up so
    /// callers see a closed channel instead of hanging.
    fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let reason = loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break "shutdown requested";
                        }
                    }
                    result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
                        if let Err(_e) = result {
                            // Suppress the error log during deliberate shutdown.
                            if !*stop_rx.borrow() {
                                async_node_error!(error = %_e, "read error");
                            }
                            break "connection error";
                        }
                    }
                }
            };

            Self::cleanup_requests(&requests, &pending_requests, reason).await;
        });
    }

    /// Spawns the task that, every 5 seconds, drops requests older than
    /// `DEFAULT_TIMEOUT * 2` so their callers observe a closed channel.
    fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(5));

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    _ = interval.tick() => {
                        let now = Instant::now();
                        let mut expired = Vec::new();

                        // Collect first, then remove, to avoid mutating the
                        // map while iterating it.
                        for entry in requests.iter() {
                            let age = now.duration_since(entry.created_at);
                            if age > DEFAULT_TIMEOUT * 2 {
                                expired.push(*entry.key());
                            }
                        }

                        for handle in expired {
                            if let Some((_, _req)) = requests.remove(&handle) {
                                pending_requests.fetch_sub(1, Ordering::Relaxed);
                                async_node_warn!(
                                    handle = handle,
                                    elapsed = ?now.duration_since(_req.created_at),
                                    "request timed out"
                                );
                            }
                        }
                    }
                }
            }
        });
    }

    /// Reads and dispatches exactly one framed response message.
    ///
    /// Frame layout consumed here: 4-byte big-endian length prefix, then the
    /// body; the body starts with 1 skipped byte followed by the i64 handle.
    /// Parsing of the remainder is offloaded to a spawned task so the read
    /// loop can go straight back to the socket.
    ///
    /// # Errors
    /// Returns an error on socket failure or when the frame exceeds
    /// `MAX_MESSAGE_SIZE`.
    async fn async_job(
        tcp: &mut ReadHalf<TcpStream>,
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
    ) -> Result<(), VoltError> {
        let mut len_buf = [0u8; 4];
        tcp.read_exact(&mut len_buf).await?;
        let msg_len = BigEndian::read_u32(&len_buf) as usize;

        if msg_len > MAX_MESSAGE_SIZE {
            return Err(VoltError::MessageTooLarge(msg_len));
        }

        if msg_len == 0 {
            return Ok(());
        }

        let mut buf = BytesMut::with_capacity(msg_len);
        buf.resize(msg_len, 0);
        tcp.read_exact(&mut buf).await?;

        // Skip the leading byte, then read the invocation handle.
        let _ = buf.get_u8();
        let handle = buf.get_i64();

        // Ping responses carry no caller to notify.
        if handle == PING_HANDLE {
            return Ok(());
        }

        if let Some((_, req)) = requests.remove(&handle) {
            pending_requests.fetch_sub(1, Ordering::Relaxed);

            let frozen_buf = buf.freeze();

            tokio::spawn(async move {
                match Self::parse_response(frozen_buf, handle) {
                    Ok(table) => {
                        let _ = req.channel.send(table).await;
                    }
                    Err(_e) => {
                        // NOTE(review): on parse failure `req.channel` is
                        // dropped unsent, so the caller sees a closed channel
                        // rather than this error.
                        async_node_error!(handle = handle, error = %_e, "parse error");
                    }
                }
            });
        } else {
            // Late response for a request the timeout checker already removed,
            // or an unsolicited handle.
            async_node_warn!(handle = handle, "received response for unknown handle");
        }

        Ok(())
    }

    /// Parses a response body (positioned after the handle) into a table.
    fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
        let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
        let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
        let table = new_volt_table(&mut byte_buf, info)?;
        Ok(table)
    }

    /// Drops every pending request (closing each caller's channel) and
    /// resets the pending counter. Called once when the reader task exits.
    async fn cleanup_requests(
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
        _reason: &str,
    ) {
        let pending_count = requests.len();

        if pending_count > 0 {
            async_node_warn!(
                pending_count = pending_count,
                reason = _reason,
                "cleaning up pending requests"
            );
        }

        requests.clear();
        pending_requests.store(0, Ordering::Relaxed);
    }
}
517
518pub async fn async_block_for_result(
520 rx: &mut mpsc::Receiver<VoltTable>,
521) -> Result<VoltTable, VoltError> {
522 match rx.recv().await {
523 Some(mut table) => match table.has_error() {
524 None => Ok(table),
525 Some(err) => Err(err),
526 },
527 None => Err(VoltError::ConnectionNotAvailable),
528 }
529}
530
531pub async fn async_block_for_result_with_timeout(
533 rx: &mut mpsc::Receiver<VoltTable>,
534 timeout_duration: Duration,
535) -> Result<VoltTable, VoltError> {
536 match timeout(timeout_duration, rx.recv()).await {
537 Ok(Some(mut table)) => match table.has_error() {
538 None => Ok(table),
539 Some(err) => Err(err),
540 },
541 Ok(None) => Err(VoltError::ConnectionNotAvailable),
542 Err(_) => Err(VoltError::Timeout),
543 }
544}
545
546impl VoltError {
548 pub fn message_too_large(size: usize) -> Self {
549 VoltError::MessageTooLarge(size)
550 }
551
552 pub fn connection_closed() -> Self {
553 VoltError::ConnectionClosed
554 }
555
556 pub fn timeout() -> Self {
557 VoltError::Timeout
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a disconnected `AsyncNode` with the given pending count,
    /// suitable for unit-testing the pure accessor methods.
    fn test_node(pending: usize) -> AsyncNode {
        AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(pending)),
        }
    }

    #[tokio::test]
    async fn test_sequence_generation() {
        let node = test_node(0);
        let first = node.get_sequence();
        let second = node.get_sequence();
        assert_eq!(second, first + 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        assert_eq!(test_node(5).pending_count(), 5);
    }
}